@elaraai/east-py-datascience 0.0.2-beta.32 → 0.0.2-beta.34
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
|
@@ -13,6 +13,33 @@
|
|
|
13
13
|
*/
|
|
14
14
|
import { StructType, VariantType, OptionType, IntegerType, FloatType, BlobType, ArrayType, NullType, BooleanType, FunctionType, StringType } from "@elaraai/east";
|
|
15
15
|
export { VectorType, MatrixType } from "../types.js";
|
|
16
|
+
/**
|
|
17
|
+
* Return embedding mode for Decision Transformer.
|
|
18
|
+
*/
|
|
19
|
+
export declare const ReturnEmbeddingType: VariantType<{
|
|
20
|
+
/** Single return value for entire sequence */
|
|
21
|
+
global: NullType;
|
|
22
|
+
/** Return-to-go at each timestep */
|
|
23
|
+
per_timestep: NullType;
|
|
24
|
+
}>;
|
|
25
|
+
/**
|
|
26
|
+
* Per-head output configuration for multi_head_mixed.
|
|
27
|
+
*/
|
|
28
|
+
export declare const HeadConfigType: StructType<{
|
|
29
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
30
|
+
head_type: VariantType<{
|
|
31
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
32
|
+
binary: NullType;
|
|
33
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
34
|
+
multiclass: StructType<{
|
|
35
|
+
n_classes: IntegerType;
|
|
36
|
+
}>;
|
|
37
|
+
}>;
|
|
38
|
+
/** Optional class weights for this head */
|
|
39
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
40
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
41
|
+
conditional_on: OptionType<IntegerType>;
|
|
42
|
+
}>;
|
|
16
43
|
/**
|
|
17
44
|
* Lightning output mode - determines loss function and output activation.
|
|
18
45
|
*/
|
|
@@ -40,6 +67,31 @@ export declare const LightningOutputType: VariantType<{
|
|
|
40
67
|
/** Optional class weights matrix (n_heads, n_classes) */
|
|
41
68
|
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
42
69
|
}>;
|
|
70
|
+
/**
|
|
71
|
+
* Mixed output types per head.
|
|
72
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
73
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
74
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
75
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
76
|
+
*/
|
|
77
|
+
multi_head_mixed: StructType<{
|
|
78
|
+
/** Array of head configurations */
|
|
79
|
+
heads: ArrayType<StructType<{
|
|
80
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
81
|
+
head_type: VariantType<{
|
|
82
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
83
|
+
binary: NullType;
|
|
84
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
85
|
+
multiclass: StructType<{
|
|
86
|
+
n_classes: IntegerType;
|
|
87
|
+
}>;
|
|
88
|
+
}>;
|
|
89
|
+
/** Optional class weights for this head */
|
|
90
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
91
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
92
|
+
conditional_on: OptionType<IntegerType>;
|
|
93
|
+
}>>;
|
|
94
|
+
}>;
|
|
43
95
|
}>;
|
|
44
96
|
/**
|
|
45
97
|
* Cell type for sequential architectures.
|
|
@@ -122,6 +174,36 @@ export declare const LightningArchitectureType: VariantType<{
|
|
|
122
174
|
/** Optional condition dimension for conditional generation */
|
|
123
175
|
condition_dim: OptionType<IntegerType>;
|
|
124
176
|
}>;
|
|
177
|
+
/**
|
|
178
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
179
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
180
|
+
* Predicts actions conditioned on desired return and state history.
|
|
181
|
+
*/
|
|
182
|
+
decision_transformer: StructType<{
|
|
183
|
+
/** Sequence length (timesteps) */
|
|
184
|
+
sequence_length: IntegerType;
|
|
185
|
+
/** State dimension per timestep */
|
|
186
|
+
state_dim: IntegerType;
|
|
187
|
+
/** Action dimension per timestep */
|
|
188
|
+
action_dim: IntegerType;
|
|
189
|
+
/** Model dimension (transformer hidden size) */
|
|
190
|
+
d_model: IntegerType;
|
|
191
|
+
/** Number of attention heads */
|
|
192
|
+
n_attention_heads: IntegerType;
|
|
193
|
+
/** Number of transformer layers */
|
|
194
|
+
n_layers: IntegerType;
|
|
195
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
196
|
+
d_ff: OptionType<IntegerType>;
|
|
197
|
+
/** Dropout rate */
|
|
198
|
+
dropout: OptionType<FloatType>;
|
|
199
|
+
/** Whether return is per-timestep or global */
|
|
200
|
+
return_embedding: VariantType<{
|
|
201
|
+
/** Single return value for entire sequence */
|
|
202
|
+
global: NullType;
|
|
203
|
+
/** Return-to-go at each timestep */
|
|
204
|
+
per_timestep: NullType;
|
|
205
|
+
}>;
|
|
206
|
+
}>;
|
|
125
207
|
}>;
|
|
126
208
|
/**
|
|
127
209
|
* Epoch callback function type: (epoch, train_loss, val_loss) -> void
|
|
@@ -203,6 +285,36 @@ export declare const LightningConfigType: StructType<{
|
|
|
203
285
|
/** Optional condition dimension for conditional generation */
|
|
204
286
|
condition_dim: OptionType<IntegerType>;
|
|
205
287
|
}>;
|
|
288
|
+
/**
|
|
289
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
290
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
291
|
+
* Predicts actions conditioned on desired return and state history.
|
|
292
|
+
*/
|
|
293
|
+
decision_transformer: StructType<{
|
|
294
|
+
/** Sequence length (timesteps) */
|
|
295
|
+
sequence_length: IntegerType;
|
|
296
|
+
/** State dimension per timestep */
|
|
297
|
+
state_dim: IntegerType;
|
|
298
|
+
/** Action dimension per timestep */
|
|
299
|
+
action_dim: IntegerType;
|
|
300
|
+
/** Model dimension (transformer hidden size) */
|
|
301
|
+
d_model: IntegerType;
|
|
302
|
+
/** Number of attention heads */
|
|
303
|
+
n_attention_heads: IntegerType;
|
|
304
|
+
/** Number of transformer layers */
|
|
305
|
+
n_layers: IntegerType;
|
|
306
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
307
|
+
d_ff: OptionType<IntegerType>;
|
|
308
|
+
/** Dropout rate */
|
|
309
|
+
dropout: OptionType<FloatType>;
|
|
310
|
+
/** Whether return is per-timestep or global */
|
|
311
|
+
return_embedding: VariantType<{
|
|
312
|
+
/** Single return value for entire sequence */
|
|
313
|
+
global: NullType;
|
|
314
|
+
/** Return-to-go at each timestep */
|
|
315
|
+
per_timestep: NullType;
|
|
316
|
+
}>;
|
|
317
|
+
}>;
|
|
206
318
|
}>;
|
|
207
319
|
/** Output mode (determines loss function) */
|
|
208
320
|
output: VariantType<{
|
|
@@ -229,6 +341,31 @@ export declare const LightningConfigType: StructType<{
|
|
|
229
341
|
/** Optional class weights matrix (n_heads, n_classes) */
|
|
230
342
|
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
231
343
|
}>;
|
|
344
|
+
/**
|
|
345
|
+
* Mixed output types per head.
|
|
346
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
347
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
348
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
349
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
350
|
+
*/
|
|
351
|
+
multi_head_mixed: StructType<{
|
|
352
|
+
/** Array of head configurations */
|
|
353
|
+
heads: ArrayType<StructType<{
|
|
354
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
355
|
+
head_type: VariantType<{
|
|
356
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
357
|
+
binary: NullType;
|
|
358
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
359
|
+
multiclass: StructType<{
|
|
360
|
+
n_classes: IntegerType;
|
|
361
|
+
}>;
|
|
362
|
+
}>;
|
|
363
|
+
/** Optional class weights for this head */
|
|
364
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
365
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
366
|
+
conditional_on: OptionType<IntegerType>;
|
|
367
|
+
}>>;
|
|
368
|
+
}>;
|
|
232
369
|
}>;
|
|
233
370
|
/** Learning rate (default: 1e-3) */
|
|
234
371
|
learning_rate: OptionType<FloatType>;
|
|
@@ -401,6 +538,36 @@ export declare const lightning_train: import("@elaraai/east").PlatformDefinition
|
|
|
401
538
|
/** Optional condition dimension for conditional generation */
|
|
402
539
|
condition_dim: OptionType<IntegerType>;
|
|
403
540
|
}>;
|
|
541
|
+
/**
|
|
542
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
543
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
544
|
+
* Predicts actions conditioned on desired return and state history.
|
|
545
|
+
*/
|
|
546
|
+
decision_transformer: StructType<{
|
|
547
|
+
/** Sequence length (timesteps) */
|
|
548
|
+
sequence_length: IntegerType;
|
|
549
|
+
/** State dimension per timestep */
|
|
550
|
+
state_dim: IntegerType;
|
|
551
|
+
/** Action dimension per timestep */
|
|
552
|
+
action_dim: IntegerType;
|
|
553
|
+
/** Model dimension (transformer hidden size) */
|
|
554
|
+
d_model: IntegerType;
|
|
555
|
+
/** Number of attention heads */
|
|
556
|
+
n_attention_heads: IntegerType;
|
|
557
|
+
/** Number of transformer layers */
|
|
558
|
+
n_layers: IntegerType;
|
|
559
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
560
|
+
d_ff: OptionType<IntegerType>;
|
|
561
|
+
/** Dropout rate */
|
|
562
|
+
dropout: OptionType<FloatType>;
|
|
563
|
+
/** Whether return is per-timestep or global */
|
|
564
|
+
return_embedding: VariantType<{
|
|
565
|
+
/** Single return value for entire sequence */
|
|
566
|
+
global: NullType;
|
|
567
|
+
/** Return-to-go at each timestep */
|
|
568
|
+
per_timestep: NullType;
|
|
569
|
+
}>;
|
|
570
|
+
}>;
|
|
404
571
|
}>;
|
|
405
572
|
/** Output mode (determines loss function) */
|
|
406
573
|
output: VariantType<{
|
|
@@ -427,6 +594,31 @@ export declare const lightning_train: import("@elaraai/east").PlatformDefinition
|
|
|
427
594
|
/** Optional class weights matrix (n_heads, n_classes) */
|
|
428
595
|
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
429
596
|
}>;
|
|
597
|
+
/**
|
|
598
|
+
* Mixed output types per head.
|
|
599
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
600
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
601
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
602
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
603
|
+
*/
|
|
604
|
+
multi_head_mixed: StructType<{
|
|
605
|
+
/** Array of head configurations */
|
|
606
|
+
heads: ArrayType<StructType<{
|
|
607
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
608
|
+
head_type: VariantType<{
|
|
609
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
610
|
+
binary: NullType;
|
|
611
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
612
|
+
multiclass: StructType<{
|
|
613
|
+
n_classes: IntegerType;
|
|
614
|
+
}>;
|
|
615
|
+
}>;
|
|
616
|
+
/** Optional class weights for this head */
|
|
617
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
618
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
619
|
+
conditional_on: OptionType<IntegerType>;
|
|
620
|
+
}>>;
|
|
621
|
+
}>;
|
|
430
622
|
}>;
|
|
431
623
|
/** Learning rate (default: 1e-3) */
|
|
432
624
|
learning_rate: OptionType<FloatType>;
|
|
@@ -624,6 +816,305 @@ export declare const lightning_generate_sequence: import("@elaraai/east").Platfo
|
|
|
624
816
|
/** If true, return probabilities. If false, return samples. */
|
|
625
817
|
return_probs: BooleanType;
|
|
626
818
|
}>], ArrayType<ArrayType<FloatType>>>;
|
|
819
|
+
/**
|
|
820
|
+
* Configuration for Decision Transformer trajectory generation.
|
|
821
|
+
*/
|
|
822
|
+
export declare const TrajectoryGenerateConfigType: StructType<{
|
|
823
|
+
/** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
|
|
824
|
+
temperature: FloatType;
|
|
825
|
+
/** Whether to return probabilities or samples */
|
|
826
|
+
return_probs: BooleanType;
|
|
827
|
+
/** Optional constraint mask: (seq_len, action_dim) - FALSE disables action */
|
|
828
|
+
action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
829
|
+
/** Optional temporal mask: (seq_len,) - FALSE marks invalid timesteps */
|
|
830
|
+
temporal_mask: OptionType<ArrayType<FloatType>>;
|
|
831
|
+
/** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
|
|
832
|
+
head_configs: OptionType<ArrayType<StructType<{
|
|
833
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
834
|
+
head_type: VariantType<{
|
|
835
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
836
|
+
binary: NullType;
|
|
837
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
838
|
+
multiclass: StructType<{
|
|
839
|
+
n_classes: IntegerType;
|
|
840
|
+
}>;
|
|
841
|
+
}>;
|
|
842
|
+
/** Optional class weights for this head */
|
|
843
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
844
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
845
|
+
conditional_on: OptionType<IntegerType>;
|
|
846
|
+
}>>>;
|
|
847
|
+
/** Optional action prefix: (seq_len, action_dim) - known actions for timesteps 0..start_timestep-1 */
|
|
848
|
+
action_prefix: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
849
|
+
/** Timestep to start generation from (0 = generate all, 5 = use prefix for 0-4, generate 5+) */
|
|
850
|
+
start_timestep: OptionType<IntegerType>;
|
|
851
|
+
}>;
|
|
852
|
+
/**
|
|
853
|
+
* Train with trajectory data for return-conditioned sequence generation.
|
|
854
|
+
*
|
|
855
|
+
* Use with decision_transformer architecture.
|
|
856
|
+
*
|
|
857
|
+
* @param returns - Return per sample (n_samples,) - actual outcome achieved
|
|
858
|
+
* @param states - State matrices: n_samples × (seq_len, state_dim)
|
|
859
|
+
* @param actions - Action matrices: n_samples × (seq_len, action_dim)
|
|
860
|
+
* @param masks - Temporal masks: n_samples × (seq_len,) - valid timesteps
|
|
861
|
+
* @param config - Training configuration with decision_transformer architecture
|
|
862
|
+
* @returns Training result with model blob and metrics
|
|
863
|
+
*/
|
|
864
|
+
export declare const lightning_train_trajectory: import("@elaraai/east").PlatformDefinition<[ArrayType<FloatType>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<ArrayType<FloatType>>, StructType<{
|
|
865
|
+
/** Model architecture */
|
|
866
|
+
architecture: VariantType<{
|
|
867
|
+
/** Simple MLP: input → hidden → output */
|
|
868
|
+
mlp: StructType<{
|
|
869
|
+
/** Hidden layer sizes */
|
|
870
|
+
hidden_layers: ArrayType<IntegerType>;
|
|
871
|
+
}>;
|
|
872
|
+
/** Autoencoder: input → encoder → latent → decoder → output */
|
|
873
|
+
autoencoder: StructType<{
|
|
874
|
+
/** Encoder hidden layer sizes */
|
|
875
|
+
encoder_layers: ArrayType<IntegerType>;
|
|
876
|
+
/** Latent dimension (bottleneck) */
|
|
877
|
+
latent_dim: IntegerType;
|
|
878
|
+
/** Decoder hidden layer sizes */
|
|
879
|
+
decoder_layers: ArrayType<IntegerType>;
|
|
880
|
+
}>;
|
|
881
|
+
/** Conv1D: 1D convolutional autoencoder for temporal patterns */
|
|
882
|
+
conv1d: StructType<{
|
|
883
|
+
/** Number of channels (e.g., additive types) */
|
|
884
|
+
n_channels: IntegerType;
|
|
885
|
+
/** Sequence length (e.g., days) */
|
|
886
|
+
sequence_length: IntegerType;
|
|
887
|
+
/** Conv layer channel sizes */
|
|
888
|
+
conv_channels: ArrayType<IntegerType>;
|
|
889
|
+
/** Kernel size for convolutions (must be odd) */
|
|
890
|
+
kernel_size: IntegerType;
|
|
891
|
+
/** Latent dimension after flattening */
|
|
892
|
+
latent_dim: IntegerType;
|
|
893
|
+
/** Optional condition dimension for conditional generation */
|
|
894
|
+
condition_dim: OptionType<IntegerType>;
|
|
895
|
+
}>;
|
|
896
|
+
/** Sequential: LSTM/GRU autoencoder for long-range dependencies */
|
|
897
|
+
sequential: StructType<{
|
|
898
|
+
/** Number of channels (e.g., additive types) */
|
|
899
|
+
n_channels: IntegerType;
|
|
900
|
+
/** Sequence length (e.g., days) */
|
|
901
|
+
sequence_length: IntegerType;
|
|
902
|
+
/** RNN hidden size */
|
|
903
|
+
hidden_size: IntegerType;
|
|
904
|
+
/** Number of RNN layers */
|
|
905
|
+
n_layers: IntegerType;
|
|
906
|
+
/** Cell type: lstm or gru */
|
|
907
|
+
cell_type: VariantType<{
|
|
908
|
+
lstm: NullType;
|
|
909
|
+
gru: NullType;
|
|
910
|
+
}>;
|
|
911
|
+
/** Latent dimension (from final hidden state) */
|
|
912
|
+
latent_dim: IntegerType;
|
|
913
|
+
/** Bidirectional encoder (decoder is always unidirectional) */
|
|
914
|
+
bidirectional: BooleanType;
|
|
915
|
+
/** Optional condition dimension for conditional generation */
|
|
916
|
+
condition_dim: OptionType<IntegerType>;
|
|
917
|
+
}>;
|
|
918
|
+
/** Transformer: attention-based autoencoder for complex patterns */
|
|
919
|
+
transformer: StructType<{
|
|
920
|
+
/** Number of channels (e.g., additive types) */
|
|
921
|
+
n_channels: IntegerType;
|
|
922
|
+
/** Sequence length (e.g., days) */
|
|
923
|
+
sequence_length: IntegerType;
|
|
924
|
+
/** Model dimension */
|
|
925
|
+
d_model: IntegerType;
|
|
926
|
+
/** Number of attention heads (must divide d_model evenly) */
|
|
927
|
+
n_attention_heads: IntegerType;
|
|
928
|
+
/** Number of transformer layers */
|
|
929
|
+
n_layers: IntegerType;
|
|
930
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
931
|
+
d_ff: OptionType<IntegerType>;
|
|
932
|
+
/** Latent dimension (mean pooled output) */
|
|
933
|
+
latent_dim: IntegerType;
|
|
934
|
+
/** Optional condition dimension for conditional generation */
|
|
935
|
+
condition_dim: OptionType<IntegerType>;
|
|
936
|
+
}>;
|
|
937
|
+
/**
|
|
938
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
939
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
940
|
+
* Predicts actions conditioned on desired return and state history.
|
|
941
|
+
*/
|
|
942
|
+
decision_transformer: StructType<{
|
|
943
|
+
/** Sequence length (timesteps) */
|
|
944
|
+
sequence_length: IntegerType;
|
|
945
|
+
/** State dimension per timestep */
|
|
946
|
+
state_dim: IntegerType;
|
|
947
|
+
/** Action dimension per timestep */
|
|
948
|
+
action_dim: IntegerType;
|
|
949
|
+
/** Model dimension (transformer hidden size) */
|
|
950
|
+
d_model: IntegerType;
|
|
951
|
+
/** Number of attention heads */
|
|
952
|
+
n_attention_heads: IntegerType;
|
|
953
|
+
/** Number of transformer layers */
|
|
954
|
+
n_layers: IntegerType;
|
|
955
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
956
|
+
d_ff: OptionType<IntegerType>;
|
|
957
|
+
/** Dropout rate */
|
|
958
|
+
dropout: OptionType<FloatType>;
|
|
959
|
+
/** Whether return is per-timestep or global */
|
|
960
|
+
return_embedding: VariantType<{
|
|
961
|
+
/** Single return value for entire sequence */
|
|
962
|
+
global: NullType;
|
|
963
|
+
/** Return-to-go at each timestep */
|
|
964
|
+
per_timestep: NullType;
|
|
965
|
+
}>;
|
|
966
|
+
}>;
|
|
967
|
+
}>;
|
|
968
|
+
/** Output mode (determines loss function) */
|
|
969
|
+
output: VariantType<{
|
|
970
|
+
/** Regression: MSE loss, no activation */
|
|
971
|
+
regression: NullType;
|
|
972
|
+
/** Binary: BCE loss, sigmoid activation */
|
|
973
|
+
binary: StructType<{
|
|
974
|
+
/** Optional per-position pos_weights for class imbalance [output_dim] */
|
|
975
|
+
pos_weight: OptionType<ArrayType<FloatType>>;
|
|
976
|
+
}>;
|
|
977
|
+
/** Multiclass: CrossEntropy loss, softmax activation */
|
|
978
|
+
multiclass: StructType<{
|
|
979
|
+
/** Number of classes */
|
|
980
|
+
n_classes: IntegerType;
|
|
981
|
+
/** Optional per-class weights */
|
|
982
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
983
|
+
}>;
|
|
984
|
+
/** Multi-head categorical: N independent CrossEntropy heads */
|
|
985
|
+
multi_head: StructType<{
|
|
986
|
+
/** Number of heads (e.g., 84 time slots) */
|
|
987
|
+
n_heads: IntegerType;
|
|
988
|
+
/** Classes per head (e.g., 4 bins) */
|
|
989
|
+
n_classes_per_head: IntegerType;
|
|
990
|
+
/** Optional class weights matrix (n_heads, n_classes) */
|
|
991
|
+
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
992
|
+
}>;
|
|
993
|
+
/**
|
|
994
|
+
* Mixed output types per head.
|
|
995
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
996
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
997
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
998
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
999
|
+
*/
|
|
1000
|
+
multi_head_mixed: StructType<{
|
|
1001
|
+
/** Array of head configurations */
|
|
1002
|
+
heads: ArrayType<StructType<{
|
|
1003
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
1004
|
+
head_type: VariantType<{
|
|
1005
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
1006
|
+
binary: NullType;
|
|
1007
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
1008
|
+
multiclass: StructType<{
|
|
1009
|
+
n_classes: IntegerType;
|
|
1010
|
+
}>;
|
|
1011
|
+
}>;
|
|
1012
|
+
/** Optional class weights for this head */
|
|
1013
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
1014
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
1015
|
+
conditional_on: OptionType<IntegerType>;
|
|
1016
|
+
}>>;
|
|
1017
|
+
}>;
|
|
1018
|
+
}>;
|
|
1019
|
+
/** Learning rate (default: 1e-3) */
|
|
1020
|
+
learning_rate: OptionType<FloatType>;
|
|
1021
|
+
/** Maximum epochs (default: 100) */
|
|
1022
|
+
max_epochs: OptionType<IntegerType>;
|
|
1023
|
+
/** Early stopping patience (default: 10) */
|
|
1024
|
+
patience: OptionType<IntegerType>;
|
|
1025
|
+
/** Batch size (default: 32) */
|
|
1026
|
+
batch_size: OptionType<IntegerType>;
|
|
1027
|
+
/** Dropout rate (default: 0.1) */
|
|
1028
|
+
dropout: OptionType<FloatType>;
|
|
1029
|
+
/** Gradient clipping value (default: 1.0) */
|
|
1030
|
+
gradient_clip: OptionType<FloatType>;
|
|
1031
|
+
/** L2 regularization weight decay (default: 0) */
|
|
1032
|
+
weight_decay: OptionType<FloatType>;
|
|
1033
|
+
/** Random seed for reproducibility */
|
|
1034
|
+
random_state: OptionType<IntegerType>;
|
|
1035
|
+
/** Optional callback called each epoch */
|
|
1036
|
+
epoch_callback: OptionType<FunctionType<[IntegerType, FloatType, FloatType], NullType>>;
|
|
1037
|
+
}>], StructType<{
|
|
1038
|
+
/** Trained model blob */
|
|
1039
|
+
model: VariantType<{
|
|
1040
|
+
lightning: StructType<{
|
|
1041
|
+
/** Serialized model data (state_dict + hparams) */
|
|
1042
|
+
data: BlobType;
|
|
1043
|
+
/** Input dimension */
|
|
1044
|
+
n_features: IntegerType;
|
|
1045
|
+
/** Output dimension */
|
|
1046
|
+
output_dim: IntegerType;
|
|
1047
|
+
/** Architecture type */
|
|
1048
|
+
architecture_type: StringType;
|
|
1049
|
+
/** Output type */
|
|
1050
|
+
output_type: StringType;
|
|
1051
|
+
/** Latent dimension (autoencoder only) */
|
|
1052
|
+
latent_dim: OptionType<IntegerType>;
|
|
1053
|
+
}>;
|
|
1054
|
+
}>;
|
|
1055
|
+
/** Final training loss */
|
|
1056
|
+
train_loss: FloatType;
|
|
1057
|
+
/** Final validation loss */
|
|
1058
|
+
val_loss: FloatType;
|
|
1059
|
+
/** Best epoch (for early stopping) */
|
|
1060
|
+
best_epoch: IntegerType;
|
|
1061
|
+
}>>;
|
|
1062
|
+
/**
|
|
1063
|
+
* Generate action sequences autoregressively from trajectory model.
|
|
1064
|
+
*
|
|
1065
|
+
* Use with models trained via trainTrajectory.
|
|
1066
|
+
*
|
|
1067
|
+
* @param model - Trained model from trainTrajectory
|
|
1068
|
+
* @param states - State matrices: n_samples × (seq_len, state_dim)
|
|
1069
|
+
* @param target_returns - Target returns: (n_samples,)
|
|
1070
|
+
* @param config - Generation configuration
|
|
1071
|
+
* @returns Generated actions: n_samples × (seq_len, action_dim)
|
|
1072
|
+
*/
|
|
1073
|
+
export declare const lightning_generate_trajectory: import("@elaraai/east").PlatformDefinition<[VariantType<{
|
|
1074
|
+
lightning: StructType<{
|
|
1075
|
+
/** Serialized model data (state_dict + hparams) */
|
|
1076
|
+
data: BlobType;
|
|
1077
|
+
/** Input dimension */
|
|
1078
|
+
n_features: IntegerType;
|
|
1079
|
+
/** Output dimension */
|
|
1080
|
+
output_dim: IntegerType;
|
|
1081
|
+
/** Architecture type */
|
|
1082
|
+
architecture_type: StringType;
|
|
1083
|
+
/** Output type */
|
|
1084
|
+
output_type: StringType;
|
|
1085
|
+
/** Latent dimension (autoencoder only) */
|
|
1086
|
+
latent_dim: OptionType<IntegerType>;
|
|
1087
|
+
}>;
|
|
1088
|
+
}>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<FloatType>, StructType<{
|
|
1089
|
+
/** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
|
|
1090
|
+
temperature: FloatType;
|
|
1091
|
+
/** Whether to return probabilities or samples */
|
|
1092
|
+
return_probs: BooleanType;
|
|
1093
|
+
/** Optional constraint mask: (seq_len, action_dim) - FALSE disables action */
|
|
1094
|
+
action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
1095
|
+
/** Optional temporal mask: (seq_len,) - FALSE marks invalid timesteps */
|
|
1096
|
+
temporal_mask: OptionType<ArrayType<FloatType>>;
|
|
1097
|
+
/** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
|
|
1098
|
+
head_configs: OptionType<ArrayType<StructType<{
|
|
1099
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
1100
|
+
head_type: VariantType<{
|
|
1101
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
1102
|
+
binary: NullType;
|
|
1103
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
1104
|
+
multiclass: StructType<{
|
|
1105
|
+
n_classes: IntegerType;
|
|
1106
|
+
}>;
|
|
1107
|
+
}>;
|
|
1108
|
+
/** Optional class weights for this head */
|
|
1109
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
1110
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
1111
|
+
conditional_on: OptionType<IntegerType>;
|
|
1112
|
+
}>>>;
|
|
1113
|
+
/** Optional action prefix: (seq_len, action_dim) - known actions for timesteps 0..start_timestep-1 */
|
|
1114
|
+
action_prefix: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
1115
|
+
/** Timestep to start generation from (0 = generate all, 5 = use prefix for 0-4, generate 5+) */
|
|
1116
|
+
start_timestep: OptionType<IntegerType>;
|
|
1117
|
+
}>], ArrayType<ArrayType<ArrayType<FloatType>>>>;
|
|
627
1118
|
/**
|
|
628
1119
|
* Lightning types namespace.
|
|
629
1120
|
*/
|
|
@@ -652,6 +1143,31 @@ export declare const LightningTypes: {
|
|
|
652
1143
|
/** Optional class weights matrix (n_heads, n_classes) */
|
|
653
1144
|
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
654
1145
|
}>;
|
|
1146
|
+
/**
|
|
1147
|
+
* Mixed output types per head.
|
|
1148
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
1149
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
1150
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
1151
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
1152
|
+
*/
|
|
1153
|
+
multi_head_mixed: StructType<{
|
|
1154
|
+
/** Array of head configurations */
|
|
1155
|
+
heads: ArrayType<StructType<{
|
|
1156
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
1157
|
+
head_type: VariantType<{
|
|
1158
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
1159
|
+
binary: NullType;
|
|
1160
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
1161
|
+
multiclass: StructType<{
|
|
1162
|
+
n_classes: IntegerType;
|
|
1163
|
+
}>;
|
|
1164
|
+
}>;
|
|
1165
|
+
/** Optional class weights for this head */
|
|
1166
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
1167
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
1168
|
+
conditional_on: OptionType<IntegerType>;
|
|
1169
|
+
}>>;
|
|
1170
|
+
}>;
|
|
655
1171
|
}>;
|
|
656
1172
|
readonly ArchitectureType: VariantType<{
|
|
657
1173
|
/** Simple MLP: input → hidden → output */
|
|
@@ -724,6 +1240,36 @@ export declare const LightningTypes: {
|
|
|
724
1240
|
/** Optional condition dimension for conditional generation */
|
|
725
1241
|
condition_dim: OptionType<IntegerType>;
|
|
726
1242
|
}>;
|
|
1243
|
+
/**
|
|
1244
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
1245
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
1246
|
+
* Predicts actions conditioned on desired return and state history.
|
|
1247
|
+
*/
|
|
1248
|
+
decision_transformer: StructType<{
|
|
1249
|
+
/** Sequence length (timesteps) */
|
|
1250
|
+
sequence_length: IntegerType;
|
|
1251
|
+
/** State dimension per timestep */
|
|
1252
|
+
state_dim: IntegerType;
|
|
1253
|
+
/** Action dimension per timestep */
|
|
1254
|
+
action_dim: IntegerType;
|
|
1255
|
+
/** Model dimension (transformer hidden size) */
|
|
1256
|
+
d_model: IntegerType;
|
|
1257
|
+
/** Number of attention heads */
|
|
1258
|
+
n_attention_heads: IntegerType;
|
|
1259
|
+
/** Number of transformer layers */
|
|
1260
|
+
n_layers: IntegerType;
|
|
1261
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
1262
|
+
d_ff: OptionType<IntegerType>;
|
|
1263
|
+
/** Dropout rate */
|
|
1264
|
+
dropout: OptionType<FloatType>;
|
|
1265
|
+
/** Whether return is per-timestep or global */
|
|
1266
|
+
return_embedding: VariantType<{
|
|
1267
|
+
/** Single return value for entire sequence */
|
|
1268
|
+
global: NullType;
|
|
1269
|
+
/** Return-to-go at each timestep */
|
|
1270
|
+
per_timestep: NullType;
|
|
1271
|
+
}>;
|
|
1272
|
+
}>;
|
|
727
1273
|
}>;
|
|
728
1274
|
readonly CellType: VariantType<{
|
|
729
1275
|
lstm: NullType;
|
|
@@ -803,6 +1349,36 @@ export declare const LightningTypes: {
|
|
|
803
1349
|
/** Optional condition dimension for conditional generation */
|
|
804
1350
|
condition_dim: OptionType<IntegerType>;
|
|
805
1351
|
}>;
|
|
1352
|
+
/**
|
|
1353
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
1354
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
1355
|
+
* Predicts actions conditioned on desired return and state history.
|
|
1356
|
+
*/
|
|
1357
|
+
decision_transformer: StructType<{
|
|
1358
|
+
/** Sequence length (timesteps) */
|
|
1359
|
+
sequence_length: IntegerType;
|
|
1360
|
+
/** State dimension per timestep */
|
|
1361
|
+
state_dim: IntegerType;
|
|
1362
|
+
/** Action dimension per timestep */
|
|
1363
|
+
action_dim: IntegerType;
|
|
1364
|
+
/** Model dimension (transformer hidden size) */
|
|
1365
|
+
d_model: IntegerType;
|
|
1366
|
+
/** Number of attention heads */
|
|
1367
|
+
n_attention_heads: IntegerType;
|
|
1368
|
+
/** Number of transformer layers */
|
|
1369
|
+
n_layers: IntegerType;
|
|
1370
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
1371
|
+
d_ff: OptionType<IntegerType>;
|
|
1372
|
+
/** Dropout rate */
|
|
1373
|
+
dropout: OptionType<FloatType>;
|
|
1374
|
+
/** Whether return is per-timestep or global */
|
|
1375
|
+
return_embedding: VariantType<{
|
|
1376
|
+
/** Single return value for entire sequence */
|
|
1377
|
+
global: NullType;
|
|
1378
|
+
/** Return-to-go at each timestep */
|
|
1379
|
+
per_timestep: NullType;
|
|
1380
|
+
}>;
|
|
1381
|
+
}>;
|
|
806
1382
|
}>;
|
|
807
1383
|
/** Output mode (determines loss function) */
|
|
808
1384
|
output: VariantType<{
|
|
@@ -829,6 +1405,31 @@ export declare const LightningTypes: {
|
|
|
829
1405
|
/** Optional class weights matrix (n_heads, n_classes) */
|
|
830
1406
|
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
831
1407
|
}>;
|
|
1408
|
+
/**
|
|
1409
|
+
* Mixed output types per head.
|
|
1410
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
1411
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
1412
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
1413
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
1414
|
+
*/
|
|
1415
|
+
multi_head_mixed: StructType<{
|
|
1416
|
+
/** Array of head configurations */
|
|
1417
|
+
heads: ArrayType<StructType<{
|
|
1418
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
1419
|
+
head_type: VariantType<{
|
|
1420
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
1421
|
+
binary: NullType;
|
|
1422
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
1423
|
+
multiclass: StructType<{
|
|
1424
|
+
n_classes: IntegerType;
|
|
1425
|
+
}>;
|
|
1426
|
+
}>;
|
|
1427
|
+
/** Optional class weights for this head */
|
|
1428
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
1429
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
1430
|
+
conditional_on: OptionType<IntegerType>;
|
|
1431
|
+
}>>;
|
|
1432
|
+
}>;
|
|
832
1433
|
}>;
|
|
833
1434
|
/** Learning rate (default: 1e-3) */
|
|
834
1435
|
learning_rate: OptionType<FloatType>;
|
|
@@ -910,6 +1511,57 @@ export declare const LightningTypes: {
|
|
|
910
1511
|
/** If true, return probabilities. If false, return samples. */
|
|
911
1512
|
return_probs: BooleanType;
|
|
912
1513
|
}>;
|
|
1514
|
+
readonly ReturnEmbeddingType: VariantType<{
|
|
1515
|
+
/** Single return value for entire sequence */
|
|
1516
|
+
global: NullType;
|
|
1517
|
+
/** Return-to-go at each timestep */
|
|
1518
|
+
per_timestep: NullType;
|
|
1519
|
+
}>;
|
|
1520
|
+
readonly HeadConfigType: StructType<{
|
|
1521
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
1522
|
+
head_type: VariantType<{
|
|
1523
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
1524
|
+
binary: NullType;
|
|
1525
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
1526
|
+
multiclass: StructType<{
|
|
1527
|
+
n_classes: IntegerType;
|
|
1528
|
+
}>;
|
|
1529
|
+
}>;
|
|
1530
|
+
/** Optional class weights for this head */
|
|
1531
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
1532
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
1533
|
+
conditional_on: OptionType<IntegerType>;
|
|
1534
|
+
}>;
|
|
1535
|
+
readonly TrajectoryGenerateConfigType: StructType<{
|
|
1536
|
+
/** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
|
|
1537
|
+
temperature: FloatType;
|
|
1538
|
+
/** Whether to return probabilities or samples */
|
|
1539
|
+
return_probs: BooleanType;
|
|
1540
|
+
/** Optional constraint mask: (seq_len, action_dim) - FALSE disables action */
|
|
1541
|
+
action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
1542
|
+
/** Optional temporal mask: (seq_len,) - FALSE marks invalid timesteps */
|
|
1543
|
+
temporal_mask: OptionType<ArrayType<FloatType>>;
|
|
1544
|
+
/** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
|
|
1545
|
+
head_configs: OptionType<ArrayType<StructType<{
|
|
1546
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
1547
|
+
head_type: VariantType<{
|
|
1548
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
1549
|
+
binary: NullType;
|
|
1550
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
1551
|
+
multiclass: StructType<{
|
|
1552
|
+
n_classes: IntegerType;
|
|
1553
|
+
}>;
|
|
1554
|
+
}>;
|
|
1555
|
+
/** Optional class weights for this head */
|
|
1556
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
1557
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
1558
|
+
conditional_on: OptionType<IntegerType>;
|
|
1559
|
+
}>>>;
|
|
1560
|
+
/** Optional action prefix: (seq_len, action_dim) - known actions for timesteps 0..start_timestep-1 */
|
|
1561
|
+
action_prefix: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
1562
|
+
/** Timestep to start generation from (0 = generate all, 5 = use prefix for 0-4, generate 5+) */
|
|
1563
|
+
start_timestep: OptionType<IntegerType>;
|
|
1564
|
+
}>;
|
|
913
1565
|
};
|
|
914
1566
|
/**
|
|
915
1567
|
* Lightning platform functions namespace.
|
|
@@ -1015,6 +1667,36 @@ export declare const Lightning: {
|
|
|
1015
1667
|
/** Optional condition dimension for conditional generation */
|
|
1016
1668
|
condition_dim: OptionType<IntegerType>;
|
|
1017
1669
|
}>;
|
|
1670
|
+
/**
|
|
1671
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
1672
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
1673
|
+
* Predicts actions conditioned on desired return and state history.
|
|
1674
|
+
*/
|
|
1675
|
+
decision_transformer: StructType<{
|
|
1676
|
+
/** Sequence length (timesteps) */
|
|
1677
|
+
sequence_length: IntegerType;
|
|
1678
|
+
/** State dimension per timestep */
|
|
1679
|
+
state_dim: IntegerType;
|
|
1680
|
+
/** Action dimension per timestep */
|
|
1681
|
+
action_dim: IntegerType;
|
|
1682
|
+
/** Model dimension (transformer hidden size) */
|
|
1683
|
+
d_model: IntegerType;
|
|
1684
|
+
/** Number of attention heads */
|
|
1685
|
+
n_attention_heads: IntegerType;
|
|
1686
|
+
/** Number of transformer layers */
|
|
1687
|
+
n_layers: IntegerType;
|
|
1688
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
1689
|
+
d_ff: OptionType<IntegerType>;
|
|
1690
|
+
/** Dropout rate */
|
|
1691
|
+
dropout: OptionType<FloatType>;
|
|
1692
|
+
/** Whether return is per-timestep or global */
|
|
1693
|
+
return_embedding: VariantType<{
|
|
1694
|
+
/** Single return value for entire sequence */
|
|
1695
|
+
global: NullType;
|
|
1696
|
+
/** Return-to-go at each timestep */
|
|
1697
|
+
per_timestep: NullType;
|
|
1698
|
+
}>;
|
|
1699
|
+
}>;
|
|
1018
1700
|
}>;
|
|
1019
1701
|
/** Output mode (determines loss function) */
|
|
1020
1702
|
output: VariantType<{
|
|
@@ -1041,6 +1723,31 @@ export declare const Lightning: {
|
|
|
1041
1723
|
/** Optional class weights matrix (n_heads, n_classes) */
|
|
1042
1724
|
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
1043
1725
|
}>;
|
|
1726
|
+
/**
|
|
1727
|
+
* Mixed output types per head.
|
|
1728
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
1729
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
1730
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
1731
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
1732
|
+
*/
|
|
1733
|
+
multi_head_mixed: StructType<{
|
|
1734
|
+
/** Array of head configurations */
|
|
1735
|
+
heads: ArrayType<StructType<{
|
|
1736
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
1737
|
+
head_type: VariantType<{
|
|
1738
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
1739
|
+
binary: NullType;
|
|
1740
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
1741
|
+
multiclass: StructType<{
|
|
1742
|
+
n_classes: IntegerType;
|
|
1743
|
+
}>;
|
|
1744
|
+
}>;
|
|
1745
|
+
/** Optional class weights for this head */
|
|
1746
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
1747
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
1748
|
+
conditional_on: OptionType<IntegerType>;
|
|
1749
|
+
}>>;
|
|
1750
|
+
}>;
|
|
1044
1751
|
}>;
|
|
1045
1752
|
/** Learning rate (default: 1e-3) */
|
|
1046
1753
|
learning_rate: OptionType<FloatType>;
|
|
@@ -1210,6 +1917,283 @@ export declare const Lightning: {
|
|
|
1210
1917
|
/** If true, return probabilities. If false, return samples. */
|
|
1211
1918
|
return_probs: BooleanType;
|
|
1212
1919
|
}>], ArrayType<ArrayType<FloatType>>>;
|
|
1920
|
+
/**
|
|
1921
|
+
* Train a Decision Transformer with trajectory data.
|
|
1922
|
+
*
|
|
1923
|
+
* Trains a return-conditioned sequence generation model that learns
|
|
1924
|
+
* to predict actions given states and desired returns.
|
|
1925
|
+
*
|
|
1926
|
+
* @example
|
|
1927
|
+
* ```typescript
|
|
1928
|
+
* const result = Lightning.trainTrajectory(
|
|
1929
|
+
* returns, states, actions, masks,
|
|
1930
|
+
* {
|
|
1931
|
+
* architecture: variant("decision_transformer", {
|
|
1932
|
+
* sequence_length: 14n,
|
|
1933
|
+
* state_dim: 8n,
|
|
1934
|
+
* action_dim: 11n,
|
|
1935
|
+
* d_model: 64n,
|
|
1936
|
+
* n_attention_heads: 4n,
|
|
1937
|
+
* n_layers: 3n,
|
|
1938
|
+
* d_ff: variant("none", null),
|
|
1939
|
+
* dropout: variant("some", 0.1),
|
|
1940
|
+
* return_embedding: variant("global", null),
|
|
1941
|
+
* }),
|
|
1942
|
+
* output: variant("multi_head_mixed", { heads: [...] }),
|
|
1943
|
+
* ...
|
|
1944
|
+
* }
|
|
1945
|
+
* );
|
|
1946
|
+
* ```
|
|
1947
|
+
*/
|
|
1948
|
+
readonly trainTrajectory: import("@elaraai/east").PlatformDefinition<[ArrayType<FloatType>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<ArrayType<FloatType>>, StructType<{
|
|
1949
|
+
/** Model architecture */
|
|
1950
|
+
architecture: VariantType<{
|
|
1951
|
+
/** Simple MLP: input → hidden → output */
|
|
1952
|
+
mlp: StructType<{
|
|
1953
|
+
/** Hidden layer sizes */
|
|
1954
|
+
hidden_layers: ArrayType<IntegerType>;
|
|
1955
|
+
}>;
|
|
1956
|
+
/** Autoencoder: input → encoder → latent → decoder → output */
|
|
1957
|
+
autoencoder: StructType<{
|
|
1958
|
+
/** Encoder hidden layer sizes */
|
|
1959
|
+
encoder_layers: ArrayType<IntegerType>;
|
|
1960
|
+
/** Latent dimension (bottleneck) */
|
|
1961
|
+
latent_dim: IntegerType;
|
|
1962
|
+
/** Decoder hidden layer sizes */
|
|
1963
|
+
decoder_layers: ArrayType<IntegerType>;
|
|
1964
|
+
}>;
|
|
1965
|
+
/** Conv1D: 1D convolutional autoencoder for temporal patterns */
|
|
1966
|
+
conv1d: StructType<{
|
|
1967
|
+
/** Number of channels (e.g., additive types) */
|
|
1968
|
+
n_channels: IntegerType;
|
|
1969
|
+
/** Sequence length (e.g., days) */
|
|
1970
|
+
sequence_length: IntegerType;
|
|
1971
|
+
/** Conv layer channel sizes */
|
|
1972
|
+
conv_channels: ArrayType<IntegerType>;
|
|
1973
|
+
/** Kernel size for convolutions (must be odd) */
|
|
1974
|
+
kernel_size: IntegerType;
|
|
1975
|
+
/** Latent dimension after flattening */
|
|
1976
|
+
latent_dim: IntegerType;
|
|
1977
|
+
/** Optional condition dimension for conditional generation */
|
|
1978
|
+
condition_dim: OptionType<IntegerType>;
|
|
1979
|
+
}>;
|
|
1980
|
+
/** Sequential: LSTM/GRU autoencoder for long-range dependencies */
|
|
1981
|
+
sequential: StructType<{
|
|
1982
|
+
/** Number of channels (e.g., additive types) */
|
|
1983
|
+
n_channels: IntegerType;
|
|
1984
|
+
/** Sequence length (e.g., days) */
|
|
1985
|
+
sequence_length: IntegerType;
|
|
1986
|
+
/** RNN hidden size */
|
|
1987
|
+
hidden_size: IntegerType;
|
|
1988
|
+
/** Number of RNN layers */
|
|
1989
|
+
n_layers: IntegerType;
|
|
1990
|
+
/** Cell type: lstm or gru */
|
|
1991
|
+
cell_type: VariantType<{
|
|
1992
|
+
lstm: NullType;
|
|
1993
|
+
gru: NullType;
|
|
1994
|
+
}>;
|
|
1995
|
+
/** Latent dimension (from final hidden state) */
|
|
1996
|
+
latent_dim: IntegerType;
|
|
1997
|
+
/** Bidirectional encoder (decoder is always unidirectional) */
|
|
1998
|
+
bidirectional: BooleanType;
|
|
1999
|
+
/** Optional condition dimension for conditional generation */
|
|
2000
|
+
condition_dim: OptionType<IntegerType>;
|
|
2001
|
+
}>;
|
|
2002
|
+
/** Transformer: attention-based autoencoder for complex patterns */
|
|
2003
|
+
transformer: StructType<{
|
|
2004
|
+
/** Number of channels (e.g., additive types) */
|
|
2005
|
+
n_channels: IntegerType;
|
|
2006
|
+
/** Sequence length (e.g., days) */
|
|
2007
|
+
sequence_length: IntegerType;
|
|
2008
|
+
/** Model dimension */
|
|
2009
|
+
d_model: IntegerType;
|
|
2010
|
+
/** Number of attention heads (must divide d_model evenly) */
|
|
2011
|
+
n_attention_heads: IntegerType;
|
|
2012
|
+
/** Number of transformer layers */
|
|
2013
|
+
n_layers: IntegerType;
|
|
2014
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
2015
|
+
d_ff: OptionType<IntegerType>;
|
|
2016
|
+
/** Latent dimension (mean pooled output) */
|
|
2017
|
+
latent_dim: IntegerType;
|
|
2018
|
+
/** Optional condition dimension for conditional generation */
|
|
2019
|
+
condition_dim: OptionType<IntegerType>;
|
|
2020
|
+
}>;
|
|
2021
|
+
/**
|
|
2022
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
2023
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
2024
|
+
* Predicts actions conditioned on desired return and state history.
|
|
2025
|
+
*/
|
|
2026
|
+
decision_transformer: StructType<{
|
|
2027
|
+
/** Sequence length (timesteps) */
|
|
2028
|
+
sequence_length: IntegerType;
|
|
2029
|
+
/** State dimension per timestep */
|
|
2030
|
+
state_dim: IntegerType;
|
|
2031
|
+
/** Action dimension per timestep */
|
|
2032
|
+
action_dim: IntegerType;
|
|
2033
|
+
/** Model dimension (transformer hidden size) */
|
|
2034
|
+
d_model: IntegerType;
|
|
2035
|
+
/** Number of attention heads */
|
|
2036
|
+
n_attention_heads: IntegerType;
|
|
2037
|
+
/** Number of transformer layers */
|
|
2038
|
+
n_layers: IntegerType;
|
|
2039
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
2040
|
+
d_ff: OptionType<IntegerType>;
|
|
2041
|
+
/** Dropout rate */
|
|
2042
|
+
dropout: OptionType<FloatType>;
|
|
2043
|
+
/** Whether return is per-timestep or global */
|
|
2044
|
+
return_embedding: VariantType<{
|
|
2045
|
+
/** Single return value for entire sequence */
|
|
2046
|
+
global: NullType;
|
|
2047
|
+
/** Return-to-go at each timestep */
|
|
2048
|
+
per_timestep: NullType;
|
|
2049
|
+
}>;
|
|
2050
|
+
}>;
|
|
2051
|
+
}>;
|
|
2052
|
+
/** Output mode (determines loss function) */
|
|
2053
|
+
output: VariantType<{
|
|
2054
|
+
/** Regression: MSE loss, no activation */
|
|
2055
|
+
regression: NullType;
|
|
2056
|
+
/** Binary: BCE loss, sigmoid activation */
|
|
2057
|
+
binary: StructType<{
|
|
2058
|
+
/** Optional per-position pos_weights for class imbalance [output_dim] */
|
|
2059
|
+
pos_weight: OptionType<ArrayType<FloatType>>;
|
|
2060
|
+
}>;
|
|
2061
|
+
/** Multiclass: CrossEntropy loss, softmax activation */
|
|
2062
|
+
multiclass: StructType<{
|
|
2063
|
+
/** Number of classes */
|
|
2064
|
+
n_classes: IntegerType;
|
|
2065
|
+
/** Optional per-class weights */
|
|
2066
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
2067
|
+
}>;
|
|
2068
|
+
/** Multi-head categorical: N independent CrossEntropy heads */
|
|
2069
|
+
multi_head: StructType<{
|
|
2070
|
+
/** Number of heads (e.g., 84 time slots) */
|
|
2071
|
+
n_heads: IntegerType;
|
|
2072
|
+
/** Classes per head (e.g., 4 bins) */
|
|
2073
|
+
n_classes_per_head: IntegerType;
|
|
2074
|
+
/** Optional class weights matrix (n_heads, n_classes) */
|
|
2075
|
+
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
2076
|
+
}>;
|
|
2077
|
+
/**
|
|
2078
|
+
* Mixed output types per head.
|
|
2079
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
2080
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
2081
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
2082
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
2083
|
+
*/
|
|
2084
|
+
multi_head_mixed: StructType<{
|
|
2085
|
+
/** Array of head configurations */
|
|
2086
|
+
heads: ArrayType<StructType<{
|
|
2087
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
2088
|
+
head_type: VariantType<{
|
|
2089
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
2090
|
+
binary: NullType;
|
|
2091
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
2092
|
+
multiclass: StructType<{
|
|
2093
|
+
n_classes: IntegerType;
|
|
2094
|
+
}>;
|
|
2095
|
+
}>;
|
|
2096
|
+
/** Optional class weights for this head */
|
|
2097
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
2098
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
2099
|
+
conditional_on: OptionType<IntegerType>;
|
|
2100
|
+
}>>;
|
|
2101
|
+
}>;
|
|
2102
|
+
}>;
|
|
2103
|
+
/** Learning rate (default: 1e-3) */
|
|
2104
|
+
learning_rate: OptionType<FloatType>;
|
|
2105
|
+
/** Maximum epochs (default: 100) */
|
|
2106
|
+
max_epochs: OptionType<IntegerType>;
|
|
2107
|
+
/** Early stopping patience (default: 10) */
|
|
2108
|
+
patience: OptionType<IntegerType>;
|
|
2109
|
+
/** Batch size (default: 32) */
|
|
2110
|
+
batch_size: OptionType<IntegerType>;
|
|
2111
|
+
/** Dropout rate (default: 0.1) */
|
|
2112
|
+
dropout: OptionType<FloatType>;
|
|
2113
|
+
/** Gradient clipping value (default: 1.0) */
|
|
2114
|
+
gradient_clip: OptionType<FloatType>;
|
|
2115
|
+
/** L2 regularization weight decay (default: 0) */
|
|
2116
|
+
weight_decay: OptionType<FloatType>;
|
|
2117
|
+
/** Random seed for reproducibility */
|
|
2118
|
+
random_state: OptionType<IntegerType>;
|
|
2119
|
+
/** Optional callback called each epoch */
|
|
2120
|
+
epoch_callback: OptionType<FunctionType<[IntegerType, FloatType, FloatType], NullType>>;
|
|
2121
|
+
}>], StructType<{
|
|
2122
|
+
/** Trained model blob */
|
|
2123
|
+
model: VariantType<{
|
|
2124
|
+
lightning: StructType<{
|
|
2125
|
+
/** Serialized model data (state_dict + hparams) */
|
|
2126
|
+
data: BlobType;
|
|
2127
|
+
/** Input dimension */
|
|
2128
|
+
n_features: IntegerType;
|
|
2129
|
+
/** Output dimension */
|
|
2130
|
+
output_dim: IntegerType;
|
|
2131
|
+
/** Architecture type */
|
|
2132
|
+
architecture_type: StringType;
|
|
2133
|
+
/** Output type */
|
|
2134
|
+
output_type: StringType;
|
|
2135
|
+
/** Latent dimension (autoencoder only) */
|
|
2136
|
+
latent_dim: OptionType<IntegerType>;
|
|
2137
|
+
}>;
|
|
2138
|
+
}>;
|
|
2139
|
+
/** Final training loss */
|
|
2140
|
+
train_loss: FloatType;
|
|
2141
|
+
/** Final validation loss */
|
|
2142
|
+
val_loss: FloatType;
|
|
2143
|
+
/** Best epoch (for early stopping) */
|
|
2144
|
+
best_epoch: IntegerType;
|
|
2145
|
+
}>>;
|
|
2146
|
+
/**
|
|
2147
|
+
* Generate action sequences from a Decision Transformer.
|
|
2148
|
+
*
|
|
2149
|
+
* Autoregressively generates actions conditioned on target returns
|
|
2150
|
+
* and state sequences.
|
|
2151
|
+
*/
|
|
2152
|
+
readonly generateTrajectory: import("@elaraai/east").PlatformDefinition<[VariantType<{
|
|
2153
|
+
lightning: StructType<{
|
|
2154
|
+
/** Serialized model data (state_dict + hparams) */
|
|
2155
|
+
data: BlobType;
|
|
2156
|
+
/** Input dimension */
|
|
2157
|
+
n_features: IntegerType;
|
|
2158
|
+
/** Output dimension */
|
|
2159
|
+
output_dim: IntegerType;
|
|
2160
|
+
/** Architecture type */
|
|
2161
|
+
architecture_type: StringType;
|
|
2162
|
+
/** Output type */
|
|
2163
|
+
output_type: StringType;
|
|
2164
|
+
/** Latent dimension (autoencoder only) */
|
|
2165
|
+
latent_dim: OptionType<IntegerType>;
|
|
2166
|
+
}>;
|
|
2167
|
+
}>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<FloatType>, StructType<{
|
|
2168
|
+
/** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
|
|
2169
|
+
temperature: FloatType;
|
|
2170
|
+
/** Whether to return probabilities or samples */
|
|
2171
|
+
return_probs: BooleanType;
|
|
2172
|
+
/** Optional constraint mask: (seq_len, action_dim) - FALSE disables action */
|
|
2173
|
+
action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
2174
|
+
/** Optional temporal mask: (seq_len,) - FALSE marks invalid timesteps */
|
|
2175
|
+
temporal_mask: OptionType<ArrayType<FloatType>>;
|
|
2176
|
+
/** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
|
|
2177
|
+
head_configs: OptionType<ArrayType<StructType<{
|
|
2178
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
2179
|
+
head_type: VariantType<{
|
|
2180
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
2181
|
+
binary: NullType;
|
|
2182
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
2183
|
+
multiclass: StructType<{
|
|
2184
|
+
n_classes: IntegerType;
|
|
2185
|
+
}>;
|
|
2186
|
+
}>;
|
|
2187
|
+
/** Optional class weights for this head */
|
|
2188
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
2189
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
2190
|
+
conditional_on: OptionType<IntegerType>;
|
|
2191
|
+
}>>>;
|
|
2192
|
+
/** Optional action prefix: (seq_len, action_dim) - known actions for timesteps 0..start_timestep-1 */
|
|
2193
|
+
action_prefix: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
2194
|
+
/** Timestep to start generation from (0 = generate all, 5 = use prefix for 0-4, generate 5+) */
|
|
2195
|
+
start_timestep: OptionType<IntegerType>;
|
|
2196
|
+
}>], ArrayType<ArrayType<ArrayType<FloatType>>>>;
|
|
1213
2197
|
/**
|
|
1214
2198
|
* Type definitions for Lightning functions.
|
|
1215
2199
|
*/
|
|
@@ -1238,6 +2222,31 @@ export declare const Lightning: {
|
|
|
1238
2222
|
/** Optional class weights matrix (n_heads, n_classes) */
|
|
1239
2223
|
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
1240
2224
|
}>;
|
|
2225
|
+
/**
|
|
2226
|
+
* Mixed output types per head.
|
|
2227
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
2228
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
2229
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
2230
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
2231
|
+
*/
|
|
2232
|
+
multi_head_mixed: StructType<{
|
|
2233
|
+
/** Array of head configurations */
|
|
2234
|
+
heads: ArrayType<StructType<{
|
|
2235
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
2236
|
+
head_type: VariantType<{
|
|
2237
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
2238
|
+
binary: NullType;
|
|
2239
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
2240
|
+
multiclass: StructType<{
|
|
2241
|
+
n_classes: IntegerType;
|
|
2242
|
+
}>;
|
|
2243
|
+
}>;
|
|
2244
|
+
/** Optional class weights for this head */
|
|
2245
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
2246
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
2247
|
+
conditional_on: OptionType<IntegerType>;
|
|
2248
|
+
}>>;
|
|
2249
|
+
}>;
|
|
1241
2250
|
}>;
|
|
1242
2251
|
readonly ArchitectureType: VariantType<{
|
|
1243
2252
|
/** Simple MLP: input → hidden → output */
|
|
@@ -1310,6 +2319,36 @@ export declare const Lightning: {
|
|
|
1310
2319
|
/** Optional condition dimension for conditional generation */
|
|
1311
2320
|
condition_dim: OptionType<IntegerType>;
|
|
1312
2321
|
}>;
|
|
2322
|
+
/**
|
|
2323
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
2324
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
2325
|
+
* Predicts actions conditioned on desired return and state history.
|
|
2326
|
+
*/
|
|
2327
|
+
decision_transformer: StructType<{
|
|
2328
|
+
/** Sequence length (timesteps) */
|
|
2329
|
+
sequence_length: IntegerType;
|
|
2330
|
+
/** State dimension per timestep */
|
|
2331
|
+
state_dim: IntegerType;
|
|
2332
|
+
/** Action dimension per timestep */
|
|
2333
|
+
action_dim: IntegerType;
|
|
2334
|
+
/** Model dimension (transformer hidden size) */
|
|
2335
|
+
d_model: IntegerType;
|
|
2336
|
+
/** Number of attention heads */
|
|
2337
|
+
n_attention_heads: IntegerType;
|
|
2338
|
+
/** Number of transformer layers */
|
|
2339
|
+
n_layers: IntegerType;
|
|
2340
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
2341
|
+
d_ff: OptionType<IntegerType>;
|
|
2342
|
+
/** Dropout rate */
|
|
2343
|
+
dropout: OptionType<FloatType>;
|
|
2344
|
+
/** Whether return is per-timestep or global */
|
|
2345
|
+
return_embedding: VariantType<{
|
|
2346
|
+
/** Single return value for entire sequence */
|
|
2347
|
+
global: NullType;
|
|
2348
|
+
/** Return-to-go at each timestep */
|
|
2349
|
+
per_timestep: NullType;
|
|
2350
|
+
}>;
|
|
2351
|
+
}>;
|
|
1313
2352
|
}>;
|
|
1314
2353
|
readonly CellType: VariantType<{
|
|
1315
2354
|
lstm: NullType;
|
|
@@ -1389,6 +2428,36 @@ export declare const Lightning: {
|
|
|
1389
2428
|
/** Optional condition dimension for conditional generation */
|
|
1390
2429
|
condition_dim: OptionType<IntegerType>;
|
|
1391
2430
|
}>;
|
|
2431
|
+
/**
|
|
2432
|
+
* Decision Transformer: return-conditioned sequence generation.
|
|
2433
|
+
* Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
|
|
2434
|
+
* Predicts actions conditioned on desired return and state history.
|
|
2435
|
+
*/
|
|
2436
|
+
decision_transformer: StructType<{
|
|
2437
|
+
/** Sequence length (timesteps) */
|
|
2438
|
+
sequence_length: IntegerType;
|
|
2439
|
+
/** State dimension per timestep */
|
|
2440
|
+
state_dim: IntegerType;
|
|
2441
|
+
/** Action dimension per timestep */
|
|
2442
|
+
action_dim: IntegerType;
|
|
2443
|
+
/** Model dimension (transformer hidden size) */
|
|
2444
|
+
d_model: IntegerType;
|
|
2445
|
+
/** Number of attention heads */
|
|
2446
|
+
n_attention_heads: IntegerType;
|
|
2447
|
+
/** Number of transformer layers */
|
|
2448
|
+
n_layers: IntegerType;
|
|
2449
|
+
/** Feedforward dimension (default: 4 * d_model) */
|
|
2450
|
+
d_ff: OptionType<IntegerType>;
|
|
2451
|
+
/** Dropout rate */
|
|
2452
|
+
dropout: OptionType<FloatType>;
|
|
2453
|
+
/** Whether return is per-timestep or global */
|
|
2454
|
+
return_embedding: VariantType<{
|
|
2455
|
+
/** Single return value for entire sequence */
|
|
2456
|
+
global: NullType;
|
|
2457
|
+
/** Return-to-go at each timestep */
|
|
2458
|
+
per_timestep: NullType;
|
|
2459
|
+
}>;
|
|
2460
|
+
}>;
|
|
1392
2461
|
}>;
|
|
1393
2462
|
/** Output mode (determines loss function) */
|
|
1394
2463
|
output: VariantType<{
|
|
@@ -1415,6 +2484,31 @@ export declare const Lightning: {
|
|
|
1415
2484
|
/** Optional class weights matrix (n_heads, n_classes) */
|
|
1416
2485
|
class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
1417
2486
|
}>;
|
|
2487
|
+
/**
|
|
2488
|
+
* Mixed output types per head.
|
|
2489
|
+
* For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
|
|
2490
|
+
* Binary heads: 1 logit → sigmoid → BCE loss
|
|
2491
|
+
* Multiclass heads: n_classes logits → softmax → CE loss
|
|
2492
|
+
* Action vectors use one-hot encoding for multiclass heads.
|
|
2493
|
+
*/
|
|
2494
|
+
multi_head_mixed: StructType<{
|
|
2495
|
+
/** Array of head configurations */
|
|
2496
|
+
heads: ArrayType<StructType<{
|
|
2497
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
2498
|
+
head_type: VariantType<{
|
|
2499
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
2500
|
+
binary: NullType;
|
|
2501
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
2502
|
+
multiclass: StructType<{
|
|
2503
|
+
n_classes: IntegerType;
|
|
2504
|
+
}>;
|
|
2505
|
+
}>;
|
|
2506
|
+
/** Optional class weights for this head */
|
|
2507
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
2508
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
2509
|
+
conditional_on: OptionType<IntegerType>;
|
|
2510
|
+
}>>;
|
|
2511
|
+
}>;
|
|
1418
2512
|
}>;
|
|
1419
2513
|
/** Learning rate (default: 1e-3) */
|
|
1420
2514
|
learning_rate: OptionType<FloatType>;
|
|
@@ -1496,6 +2590,57 @@ export declare const Lightning: {
|
|
|
1496
2590
|
/** If true, return probabilities. If false, return samples. */
|
|
1497
2591
|
return_probs: BooleanType;
|
|
1498
2592
|
}>;
|
|
2593
|
+
readonly ReturnEmbeddingType: VariantType<{
|
|
2594
|
+
/** Single return value for entire sequence */
|
|
2595
|
+
global: NullType;
|
|
2596
|
+
/** Return-to-go at each timestep */
|
|
2597
|
+
per_timestep: NullType;
|
|
2598
|
+
}>;
|
|
2599
|
+
readonly HeadConfigType: StructType<{
|
|
2600
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
2601
|
+
head_type: VariantType<{
|
|
2602
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
2603
|
+
binary: NullType;
|
|
2604
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
2605
|
+
multiclass: StructType<{
|
|
2606
|
+
n_classes: IntegerType;
|
|
2607
|
+
}>;
|
|
2608
|
+
}>;
|
|
2609
|
+
/** Optional class weights for this head */
|
|
2610
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
2611
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
2612
|
+
conditional_on: OptionType<IntegerType>;
|
|
2613
|
+
}>;
|
|
2614
|
+
readonly TrajectoryGenerateConfigType: StructType<{
|
|
2615
|
+
/** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
|
|
2616
|
+
temperature: FloatType;
|
|
2617
|
+
/** Whether to return probabilities or samples */
|
|
2618
|
+
return_probs: BooleanType;
|
|
2619
|
+
/** Optional constraint mask: (seq_len, action_dim) - FALSE disables action */
|
|
2620
|
+
action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
2621
|
+
/** Optional temporal mask: (seq_len,) - FALSE marks invalid timesteps */
|
|
2622
|
+
temporal_mask: OptionType<ArrayType<FloatType>>;
|
|
2623
|
+
/** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
|
|
2624
|
+
head_configs: OptionType<ArrayType<StructType<{
|
|
2625
|
+
/** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
|
|
2626
|
+
head_type: VariantType<{
|
|
2627
|
+
/** Single binary output: 1 logit, sigmoid, BCE loss */
|
|
2628
|
+
binary: NullType;
|
|
2629
|
+
/** Multi-class output: n_classes logits, softmax, CE loss */
|
|
2630
|
+
multiclass: StructType<{
|
|
2631
|
+
n_classes: IntegerType;
|
|
2632
|
+
}>;
|
|
2633
|
+
}>;
|
|
2634
|
+
/** Optional class weights for this head */
|
|
2635
|
+
class_weights: OptionType<ArrayType<FloatType>>;
|
|
2636
|
+
/** Optional: index of head this depends on (loss only computed when that head is 1) */
|
|
2637
|
+
conditional_on: OptionType<IntegerType>;
|
|
2638
|
+
}>>>;
|
|
2639
|
+
/** Optional action prefix: (seq_len, action_dim) - known actions for timesteps 0..start_timestep-1 */
|
|
2640
|
+
action_prefix: OptionType<ArrayType<ArrayType<FloatType>>>;
|
|
2641
|
+
/** Timestep to start generation from (0 = generate all, 5 = use prefix for 0-4, generate 5+) */
|
|
2642
|
+
start_timestep: OptionType<IntegerType>;
|
|
2643
|
+
}>;
|
|
1499
2644
|
};
|
|
1500
2645
|
};
|
|
1501
2646
|
//# sourceMappingURL=lightning.d.ts.map
|