@elaraai/east-py-datascience 0.0.2-beta.32 → 0.0.2-beta.34

This diff shows the differences between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
@@ -13,6 +13,33 @@
13
13
  */
14
14
  import { StructType, VariantType, OptionType, IntegerType, FloatType, BlobType, ArrayType, NullType, BooleanType, FunctionType, StringType } from "@elaraai/east";
15
15
  export { VectorType, MatrixType } from "../types.js";
16
+ /**
17
+ * Return embedding mode for Decision Transformer.
18
+ */
19
+ export declare const ReturnEmbeddingType: VariantType<{
20
+ /** Single return value for entire sequence */
21
+ global: NullType;
22
+ /** Return-to-go at each timestep */
23
+ per_timestep: NullType;
24
+ }>;
25
+ /**
26
+ * Per-head output configuration for multi_head_mixed.
27
+ */
28
+ export declare const HeadConfigType: StructType<{
29
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
30
+ head_type: VariantType<{
31
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
32
+ binary: NullType;
33
+ /** Multi-class output: n_classes logits, softmax, CE loss */
34
+ multiclass: StructType<{
35
+ n_classes: IntegerType;
36
+ }>;
37
+ }>;
38
+ /** Optional class weights for this head */
39
+ class_weights: OptionType<ArrayType<FloatType>>;
40
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
41
+ conditional_on: OptionType<IntegerType>;
42
+ }>;
16
43
  /**
17
44
  * Lightning output mode - determines loss function and output activation.
18
45
  */
@@ -40,6 +67,31 @@ export declare const LightningOutputType: VariantType<{
40
67
  /** Optional class weights matrix (n_heads, n_classes) */
41
68
  class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
42
69
  }>;
70
+ /**
71
+ * Mixed output types per head.
72
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
73
+ * Binary heads: 1 logit → sigmoid → BCE loss
74
+ * Multiclass heads: n_classes logits → softmax → CE loss
75
+ * Action vectors use one-hot encoding for multiclass heads.
76
+ */
77
+ multi_head_mixed: StructType<{
78
+ /** Array of head configurations */
79
+ heads: ArrayType<StructType<{
80
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
81
+ head_type: VariantType<{
82
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
83
+ binary: NullType;
84
+ /** Multi-class output: n_classes logits, softmax, CE loss */
85
+ multiclass: StructType<{
86
+ n_classes: IntegerType;
87
+ }>;
88
+ }>;
89
+ /** Optional class weights for this head */
90
+ class_weights: OptionType<ArrayType<FloatType>>;
91
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
92
+ conditional_on: OptionType<IntegerType>;
93
+ }>>;
94
+ }>;
43
95
  }>;
44
96
  /**
45
97
  * Cell type for sequential architectures.
@@ -122,6 +174,36 @@ export declare const LightningArchitectureType: VariantType<{
122
174
  /** Optional condition dimension for conditional generation */
123
175
  condition_dim: OptionType<IntegerType>;
124
176
  }>;
177
+ /**
178
+ * Decision Transformer: return-conditioned sequence generation.
179
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
180
+ * Predicts actions conditioned on desired return and state history.
181
+ */
182
+ decision_transformer: StructType<{
183
+ /** Sequence length (timesteps) */
184
+ sequence_length: IntegerType;
185
+ /** State dimension per timestep */
186
+ state_dim: IntegerType;
187
+ /** Action dimension per timestep */
188
+ action_dim: IntegerType;
189
+ /** Model dimension (transformer hidden size) */
190
+ d_model: IntegerType;
191
+ /** Number of attention heads */
192
+ n_attention_heads: IntegerType;
193
+ /** Number of transformer layers */
194
+ n_layers: IntegerType;
195
+ /** Feedforward dimension (default: 4 * d_model) */
196
+ d_ff: OptionType<IntegerType>;
197
+ /** Dropout rate */
198
+ dropout: OptionType<FloatType>;
199
+ /** Whether return is per-timestep or global */
200
+ return_embedding: VariantType<{
201
+ /** Single return value for entire sequence */
202
+ global: NullType;
203
+ /** Return-to-go at each timestep */
204
+ per_timestep: NullType;
205
+ }>;
206
+ }>;
125
207
  }>;
126
208
  /**
127
209
  * Epoch callback function type: (epoch, train_loss, val_loss) -> void
@@ -203,6 +285,36 @@ export declare const LightningConfigType: StructType<{
203
285
  /** Optional condition dimension for conditional generation */
204
286
  condition_dim: OptionType<IntegerType>;
205
287
  }>;
288
+ /**
289
+ * Decision Transformer: return-conditioned sequence generation.
290
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
291
+ * Predicts actions conditioned on desired return and state history.
292
+ */
293
+ decision_transformer: StructType<{
294
+ /** Sequence length (timesteps) */
295
+ sequence_length: IntegerType;
296
+ /** State dimension per timestep */
297
+ state_dim: IntegerType;
298
+ /** Action dimension per timestep */
299
+ action_dim: IntegerType;
300
+ /** Model dimension (transformer hidden size) */
301
+ d_model: IntegerType;
302
+ /** Number of attention heads */
303
+ n_attention_heads: IntegerType;
304
+ /** Number of transformer layers */
305
+ n_layers: IntegerType;
306
+ /** Feedforward dimension (default: 4 * d_model) */
307
+ d_ff: OptionType<IntegerType>;
308
+ /** Dropout rate */
309
+ dropout: OptionType<FloatType>;
310
+ /** Whether return is per-timestep or global */
311
+ return_embedding: VariantType<{
312
+ /** Single return value for entire sequence */
313
+ global: NullType;
314
+ /** Return-to-go at each timestep */
315
+ per_timestep: NullType;
316
+ }>;
317
+ }>;
206
318
  }>;
207
319
  /** Output mode (determines loss function) */
208
320
  output: VariantType<{
@@ -229,6 +341,31 @@ export declare const LightningConfigType: StructType<{
229
341
  /** Optional class weights matrix (n_heads, n_classes) */
230
342
  class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
231
343
  }>;
344
+ /**
345
+ * Mixed output types per head.
346
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
347
+ * Binary heads: 1 logit → sigmoid → BCE loss
348
+ * Multiclass heads: n_classes logits → softmax → CE loss
349
+ * Action vectors use one-hot encoding for multiclass heads.
350
+ */
351
+ multi_head_mixed: StructType<{
352
+ /** Array of head configurations */
353
+ heads: ArrayType<StructType<{
354
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
355
+ head_type: VariantType<{
356
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
357
+ binary: NullType;
358
+ /** Multi-class output: n_classes logits, softmax, CE loss */
359
+ multiclass: StructType<{
360
+ n_classes: IntegerType;
361
+ }>;
362
+ }>;
363
+ /** Optional class weights for this head */
364
+ class_weights: OptionType<ArrayType<FloatType>>;
365
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
366
+ conditional_on: OptionType<IntegerType>;
367
+ }>>;
368
+ }>;
232
369
  }>;
233
370
  /** Learning rate (default: 1e-3) */
234
371
  learning_rate: OptionType<FloatType>;
@@ -401,6 +538,36 @@ export declare const lightning_train: import("@elaraai/east").PlatformDefinition
401
538
  /** Optional condition dimension for conditional generation */
402
539
  condition_dim: OptionType<IntegerType>;
403
540
  }>;
541
+ /**
542
+ * Decision Transformer: return-conditioned sequence generation.
543
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
544
+ * Predicts actions conditioned on desired return and state history.
545
+ */
546
+ decision_transformer: StructType<{
547
+ /** Sequence length (timesteps) */
548
+ sequence_length: IntegerType;
549
+ /** State dimension per timestep */
550
+ state_dim: IntegerType;
551
+ /** Action dimension per timestep */
552
+ action_dim: IntegerType;
553
+ /** Model dimension (transformer hidden size) */
554
+ d_model: IntegerType;
555
+ /** Number of attention heads */
556
+ n_attention_heads: IntegerType;
557
+ /** Number of transformer layers */
558
+ n_layers: IntegerType;
559
+ /** Feedforward dimension (default: 4 * d_model) */
560
+ d_ff: OptionType<IntegerType>;
561
+ /** Dropout rate */
562
+ dropout: OptionType<FloatType>;
563
+ /** Whether return is per-timestep or global */
564
+ return_embedding: VariantType<{
565
+ /** Single return value for entire sequence */
566
+ global: NullType;
567
+ /** Return-to-go at each timestep */
568
+ per_timestep: NullType;
569
+ }>;
570
+ }>;
404
571
  }>;
405
572
  /** Output mode (determines loss function) */
406
573
  output: VariantType<{
@@ -427,6 +594,31 @@ export declare const lightning_train: import("@elaraai/east").PlatformDefinition
427
594
  /** Optional class weights matrix (n_heads, n_classes) */
428
595
  class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
429
596
  }>;
597
+ /**
598
+ * Mixed output types per head.
599
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
600
+ * Binary heads: 1 logit → sigmoid → BCE loss
601
+ * Multiclass heads: n_classes logits → softmax → CE loss
602
+ * Action vectors use one-hot encoding for multiclass heads.
603
+ */
604
+ multi_head_mixed: StructType<{
605
+ /** Array of head configurations */
606
+ heads: ArrayType<StructType<{
607
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
608
+ head_type: VariantType<{
609
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
610
+ binary: NullType;
611
+ /** Multi-class output: n_classes logits, softmax, CE loss */
612
+ multiclass: StructType<{
613
+ n_classes: IntegerType;
614
+ }>;
615
+ }>;
616
+ /** Optional class weights for this head */
617
+ class_weights: OptionType<ArrayType<FloatType>>;
618
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
619
+ conditional_on: OptionType<IntegerType>;
620
+ }>>;
621
+ }>;
430
622
  }>;
431
623
  /** Learning rate (default: 1e-3) */
432
624
  learning_rate: OptionType<FloatType>;
@@ -624,6 +816,305 @@ export declare const lightning_generate_sequence: import("@elaraai/east").Platfo
624
816
  /** If true, return probabilities. If false, return samples. */
625
817
  return_probs: BooleanType;
626
818
  }>], ArrayType<ArrayType<FloatType>>>;
819
+ /**
820
+ * Configuration for Decision Transformer trajectory generation.
821
+ */
822
+ export declare const TrajectoryGenerateConfigType: StructType<{
823
+ /** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
824
+ temperature: FloatType;
825
+ /** Whether to return probabilities or samples */
826
+ return_probs: BooleanType;
827
+ /** Optional constraint mask of shape (seq_len, action_dim), declared as floats; a FALSE entry (presumably 0.0 - confirm) disables that action */
828
+ action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
829
+ /** Optional temporal mask of shape (seq_len,), declared as floats; a FALSE entry (presumably 0.0 - confirm) marks an invalid timestep */
830
+ temporal_mask: OptionType<ArrayType<FloatType>>;
831
+ /** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
832
+ head_configs: OptionType<ArrayType<StructType<{
833
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
834
+ head_type: VariantType<{
835
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
836
+ binary: NullType;
837
+ /** Multi-class output: n_classes logits, softmax, CE loss */
838
+ multiclass: StructType<{
839
+ n_classes: IntegerType;
840
+ }>;
841
+ }>;
842
+ /** Optional class weights for this head */
843
+ class_weights: OptionType<ArrayType<FloatType>>;
844
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
845
+ conditional_on: OptionType<IntegerType>;
846
+ }>>>;
847
+ /** Optional action prefix: (seq_len, action_dim) - known actions for timesteps 0..start_timestep-1 */
848
+ action_prefix: OptionType<ArrayType<ArrayType<FloatType>>>;
849
+ /** Timestep to start generation from (0 = generate all, 5 = use prefix for 0-4, generate 5+) */
850
+ start_timestep: OptionType<IntegerType>;
851
+ }>;
852
+ /**
853
+ * Train with trajectory data for return-conditioned sequence generation.
854
+ *
855
+ * Use with decision_transformer architecture.
856
+ *
857
+ * @param returns - Return per sample (n_samples,) - actual outcome achieved
858
+ * @param states - State matrices: n_samples × (seq_len, state_dim)
859
+ * @param actions - Action matrices: n_samples × (seq_len, action_dim)
860
+ * @param masks - Temporal masks: n_samples × (seq_len,) - valid timesteps
861
+ * @param config - Training configuration with decision_transformer architecture
862
+ * @returns Training result with model blob and metrics
863
+ */
864
+ export declare const lightning_train_trajectory: import("@elaraai/east").PlatformDefinition<[ArrayType<FloatType>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<ArrayType<FloatType>>, StructType<{
865
+ /** Model architecture */
866
+ architecture: VariantType<{
867
+ /** Simple MLP: input → hidden → output */
868
+ mlp: StructType<{
869
+ /** Hidden layer sizes */
870
+ hidden_layers: ArrayType<IntegerType>;
871
+ }>;
872
+ /** Autoencoder: input → encoder → latent → decoder → output */
873
+ autoencoder: StructType<{
874
+ /** Encoder hidden layer sizes */
875
+ encoder_layers: ArrayType<IntegerType>;
876
+ /** Latent dimension (bottleneck) */
877
+ latent_dim: IntegerType;
878
+ /** Decoder hidden layer sizes */
879
+ decoder_layers: ArrayType<IntegerType>;
880
+ }>;
881
+ /** Conv1D: 1D convolutional autoencoder for temporal patterns */
882
+ conv1d: StructType<{
883
+ /** Number of channels (e.g., additive types) */
884
+ n_channels: IntegerType;
885
+ /** Sequence length (e.g., days) */
886
+ sequence_length: IntegerType;
887
+ /** Conv layer channel sizes */
888
+ conv_channels: ArrayType<IntegerType>;
889
+ /** Kernel size for convolutions (must be odd) */
890
+ kernel_size: IntegerType;
891
+ /** Latent dimension after flattening */
892
+ latent_dim: IntegerType;
893
+ /** Optional condition dimension for conditional generation */
894
+ condition_dim: OptionType<IntegerType>;
895
+ }>;
896
+ /** Sequential: LSTM/GRU autoencoder for long-range dependencies */
897
+ sequential: StructType<{
898
+ /** Number of channels (e.g., additive types) */
899
+ n_channels: IntegerType;
900
+ /** Sequence length (e.g., days) */
901
+ sequence_length: IntegerType;
902
+ /** RNN hidden size */
903
+ hidden_size: IntegerType;
904
+ /** Number of RNN layers */
905
+ n_layers: IntegerType;
906
+ /** Cell type: lstm or gru */
907
+ cell_type: VariantType<{
908
+ lstm: NullType;
909
+ gru: NullType;
910
+ }>;
911
+ /** Latent dimension (from final hidden state) */
912
+ latent_dim: IntegerType;
913
+ /** Bidirectional encoder (decoder is always unidirectional) */
914
+ bidirectional: BooleanType;
915
+ /** Optional condition dimension for conditional generation */
916
+ condition_dim: OptionType<IntegerType>;
917
+ }>;
918
+ /** Transformer: attention-based autoencoder for complex patterns */
919
+ transformer: StructType<{
920
+ /** Number of channels (e.g., additive types) */
921
+ n_channels: IntegerType;
922
+ /** Sequence length (e.g., days) */
923
+ sequence_length: IntegerType;
924
+ /** Model dimension */
925
+ d_model: IntegerType;
926
+ /** Number of attention heads (must divide d_model evenly) */
927
+ n_attention_heads: IntegerType;
928
+ /** Number of transformer layers */
929
+ n_layers: IntegerType;
930
+ /** Feedforward dimension (default: 4 * d_model) */
931
+ d_ff: OptionType<IntegerType>;
932
+ /** Latent dimension (mean pooled output) */
933
+ latent_dim: IntegerType;
934
+ /** Optional condition dimension for conditional generation */
935
+ condition_dim: OptionType<IntegerType>;
936
+ }>;
937
+ /**
938
+ * Decision Transformer: return-conditioned sequence generation.
939
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
940
+ * Predicts actions conditioned on desired return and state history.
941
+ */
942
+ decision_transformer: StructType<{
943
+ /** Sequence length (timesteps) */
944
+ sequence_length: IntegerType;
945
+ /** State dimension per timestep */
946
+ state_dim: IntegerType;
947
+ /** Action dimension per timestep */
948
+ action_dim: IntegerType;
949
+ /** Model dimension (transformer hidden size) */
950
+ d_model: IntegerType;
951
+ /** Number of attention heads */
952
+ n_attention_heads: IntegerType;
953
+ /** Number of transformer layers */
954
+ n_layers: IntegerType;
955
+ /** Feedforward dimension (default: 4 * d_model) */
956
+ d_ff: OptionType<IntegerType>;
957
+ /** Dropout rate */
958
+ dropout: OptionType<FloatType>;
959
+ /** Whether return is per-timestep or global */
960
+ return_embedding: VariantType<{
961
+ /** Single return value for entire sequence */
962
+ global: NullType;
963
+ /** Return-to-go at each timestep */
964
+ per_timestep: NullType;
965
+ }>;
966
+ }>;
967
+ }>;
968
+ /** Output mode (determines loss function) */
969
+ output: VariantType<{
970
+ /** Regression: MSE loss, no activation */
971
+ regression: NullType;
972
+ /** Binary: BCE loss, sigmoid activation */
973
+ binary: StructType<{
974
+ /** Optional per-position pos_weights for class imbalance [output_dim] */
975
+ pos_weight: OptionType<ArrayType<FloatType>>;
976
+ }>;
977
+ /** Multiclass: CrossEntropy loss, softmax activation */
978
+ multiclass: StructType<{
979
+ /** Number of classes */
980
+ n_classes: IntegerType;
981
+ /** Optional per-class weights */
982
+ class_weights: OptionType<ArrayType<FloatType>>;
983
+ }>;
984
+ /** Multi-head categorical: N independent CrossEntropy heads */
985
+ multi_head: StructType<{
986
+ /** Number of heads (e.g., 84 time slots) */
987
+ n_heads: IntegerType;
988
+ /** Classes per head (e.g., 4 bins) */
989
+ n_classes_per_head: IntegerType;
990
+ /** Optional class weights matrix (n_heads, n_classes) */
991
+ class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
992
+ }>;
993
+ /**
994
+ * Mixed output types per head.
995
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
996
+ * Binary heads: 1 logit → sigmoid → BCE loss
997
+ * Multiclass heads: n_classes logits → softmax → CE loss
998
+ * Action vectors use one-hot encoding for multiclass heads.
999
+ */
1000
+ multi_head_mixed: StructType<{
1001
+ /** Array of head configurations */
1002
+ heads: ArrayType<StructType<{
1003
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
1004
+ head_type: VariantType<{
1005
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
1006
+ binary: NullType;
1007
+ /** Multi-class output: n_classes logits, softmax, CE loss */
1008
+ multiclass: StructType<{
1009
+ n_classes: IntegerType;
1010
+ }>;
1011
+ }>;
1012
+ /** Optional class weights for this head */
1013
+ class_weights: OptionType<ArrayType<FloatType>>;
1014
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
1015
+ conditional_on: OptionType<IntegerType>;
1016
+ }>>;
1017
+ }>;
1018
+ }>;
1019
+ /** Learning rate (default: 1e-3) */
1020
+ learning_rate: OptionType<FloatType>;
1021
+ /** Maximum epochs (default: 100) */
1022
+ max_epochs: OptionType<IntegerType>;
1023
+ /** Early stopping patience (default: 10) */
1024
+ patience: OptionType<IntegerType>;
1025
+ /** Batch size (default: 32) */
1026
+ batch_size: OptionType<IntegerType>;
1027
+ /** Dropout rate (default: 0.1) */
1028
+ dropout: OptionType<FloatType>;
1029
+ /** Gradient clipping value (default: 1.0) */
1030
+ gradient_clip: OptionType<FloatType>;
1031
+ /** L2 regularization weight decay (default: 0) */
1032
+ weight_decay: OptionType<FloatType>;
1033
+ /** Random seed for reproducibility */
1034
+ random_state: OptionType<IntegerType>;
1035
+ /** Optional callback called each epoch */
1036
+ epoch_callback: OptionType<FunctionType<[IntegerType, FloatType, FloatType], NullType>>;
1037
+ }>], StructType<{
1038
+ /** Trained model blob */
1039
+ model: VariantType<{
1040
+ lightning: StructType<{
1041
+ /** Serialized model data (state_dict + hparams) */
1042
+ data: BlobType;
1043
+ /** Input dimension */
1044
+ n_features: IntegerType;
1045
+ /** Output dimension */
1046
+ output_dim: IntegerType;
1047
+ /** Architecture type */
1048
+ architecture_type: StringType;
1049
+ /** Output type */
1050
+ output_type: StringType;
1051
+ /** Latent dimension (autoencoder only) */
1052
+ latent_dim: OptionType<IntegerType>;
1053
+ }>;
1054
+ }>;
1055
+ /** Final training loss */
1056
+ train_loss: FloatType;
1057
+ /** Final validation loss */
1058
+ val_loss: FloatType;
1059
+ /** Best epoch (for early stopping) */
1060
+ best_epoch: IntegerType;
1061
+ }>>;
1062
+ /**
1063
+ * Generate action sequences autoregressively from trajectory model.
1064
+ *
1065
+ * Use with models trained via trainTrajectory (see `lightning_train_trajectory`).
1066
+ *
1067
+ * @param model - Trained model from trainTrajectory
1068
+ * @param states - State matrices: n_samples × (seq_len, state_dim)
1069
+ * @param target_returns - Target returns: (n_samples,)
1070
+ * @param config - Generation configuration
1071
+ * @returns Generated actions: n_samples × (seq_len, action_dim)
1072
+ */
1073
+ export declare const lightning_generate_trajectory: import("@elaraai/east").PlatformDefinition<[VariantType<{
1074
+ lightning: StructType<{
1075
+ /** Serialized model data (state_dict + hparams) */
1076
+ data: BlobType;
1077
+ /** Input dimension */
1078
+ n_features: IntegerType;
1079
+ /** Output dimension */
1080
+ output_dim: IntegerType;
1081
+ /** Architecture type */
1082
+ architecture_type: StringType;
1083
+ /** Output type */
1084
+ output_type: StringType;
1085
+ /** Latent dimension (autoencoder only) */
1086
+ latent_dim: OptionType<IntegerType>;
1087
+ }>;
1088
+ }>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<FloatType>, StructType<{
1089
+ /** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
1090
+ temperature: FloatType;
1091
+ /** Whether to return probabilities or samples */
1092
+ return_probs: BooleanType;
1093
+ /** Optional constraint mask of shape (seq_len, action_dim), declared as floats; a FALSE entry (presumably 0.0 - confirm) disables that action */
1094
+ action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
1095
+ /** Optional temporal mask of shape (seq_len,), declared as floats; a FALSE entry (presumably 0.0 - confirm) marks an invalid timestep */
1096
+ temporal_mask: OptionType<ArrayType<FloatType>>;
1097
+ /** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
1098
+ head_configs: OptionType<ArrayType<StructType<{
1099
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
1100
+ head_type: VariantType<{
1101
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
1102
+ binary: NullType;
1103
+ /** Multi-class output: n_classes logits, softmax, CE loss */
1104
+ multiclass: StructType<{
1105
+ n_classes: IntegerType;
1106
+ }>;
1107
+ }>;
1108
+ /** Optional class weights for this head */
1109
+ class_weights: OptionType<ArrayType<FloatType>>;
1110
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
1111
+ conditional_on: OptionType<IntegerType>;
1112
+ }>>>;
1113
+ /** Optional action prefix: (seq_len, action_dim) - known actions for timesteps 0..start_timestep-1 */
1114
+ action_prefix: OptionType<ArrayType<ArrayType<FloatType>>>;
1115
+ /** Timestep to start generation from (0 = generate all, 5 = use prefix for 0-4, generate 5+) */
1116
+ start_timestep: OptionType<IntegerType>;
1117
+ }>], ArrayType<ArrayType<ArrayType<FloatType>>>>;
627
1118
  /**
628
1119
  * Lightning types namespace.
629
1120
  */
@@ -652,6 +1143,31 @@ export declare const LightningTypes: {
652
1143
  /** Optional class weights matrix (n_heads, n_classes) */
653
1144
  class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
654
1145
  }>;
1146
+ /**
1147
+ * Mixed output types per head.
1148
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
1149
+ * Binary heads: 1 logit → sigmoid → BCE loss
1150
+ * Multiclass heads: n_classes logits → softmax → CE loss
1151
+ * Action vectors use one-hot encoding for multiclass heads.
1152
+ */
1153
+ multi_head_mixed: StructType<{
1154
+ /** Array of head configurations */
1155
+ heads: ArrayType<StructType<{
1156
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
1157
+ head_type: VariantType<{
1158
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
1159
+ binary: NullType;
1160
+ /** Multi-class output: n_classes logits, softmax, CE loss */
1161
+ multiclass: StructType<{
1162
+ n_classes: IntegerType;
1163
+ }>;
1164
+ }>;
1165
+ /** Optional class weights for this head */
1166
+ class_weights: OptionType<ArrayType<FloatType>>;
1167
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
1168
+ conditional_on: OptionType<IntegerType>;
1169
+ }>>;
1170
+ }>;
655
1171
  }>;
656
1172
  readonly ArchitectureType: VariantType<{
657
1173
  /** Simple MLP: input → hidden → output */
@@ -724,6 +1240,36 @@ export declare const LightningTypes: {
724
1240
  /** Optional condition dimension for conditional generation */
725
1241
  condition_dim: OptionType<IntegerType>;
726
1242
  }>;
1243
+ /**
1244
+ * Decision Transformer: return-conditioned sequence generation.
1245
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
1246
+ * Predicts actions conditioned on desired return and state history.
1247
+ */
1248
+ decision_transformer: StructType<{
1249
+ /** Sequence length (timesteps) */
1250
+ sequence_length: IntegerType;
1251
+ /** State dimension per timestep */
1252
+ state_dim: IntegerType;
1253
+ /** Action dimension per timestep */
1254
+ action_dim: IntegerType;
1255
+ /** Model dimension (transformer hidden size) */
1256
+ d_model: IntegerType;
1257
+ /** Number of attention heads */
1258
+ n_attention_heads: IntegerType;
1259
+ /** Number of transformer layers */
1260
+ n_layers: IntegerType;
1261
+ /** Feedforward dimension (default: 4 * d_model) */
1262
+ d_ff: OptionType<IntegerType>;
1263
+ /** Dropout rate */
1264
+ dropout: OptionType<FloatType>;
1265
+ /** Whether return is per-timestep or global */
1266
+ return_embedding: VariantType<{
1267
+ /** Single return value for entire sequence */
1268
+ global: NullType;
1269
+ /** Return-to-go at each timestep */
1270
+ per_timestep: NullType;
1271
+ }>;
1272
+ }>;
727
1273
  }>;
728
1274
  readonly CellType: VariantType<{
729
1275
  lstm: NullType;
@@ -803,6 +1349,36 @@ export declare const LightningTypes: {
803
1349
  /** Optional condition dimension for conditional generation */
804
1350
  condition_dim: OptionType<IntegerType>;
805
1351
  }>;
1352
+ /**
1353
+ * Decision Transformer: return-conditioned sequence generation.
1354
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
1355
+ * Predicts actions conditioned on desired return and state history.
1356
+ */
1357
+ decision_transformer: StructType<{
1358
+ /** Sequence length (timesteps) */
1359
+ sequence_length: IntegerType;
1360
+ /** State dimension per timestep */
1361
+ state_dim: IntegerType;
1362
+ /** Action dimension per timestep */
1363
+ action_dim: IntegerType;
1364
+ /** Model dimension (transformer hidden size) */
1365
+ d_model: IntegerType;
1366
+ /** Number of attention heads */
1367
+ n_attention_heads: IntegerType;
1368
+ /** Number of transformer layers */
1369
+ n_layers: IntegerType;
1370
+ /** Feedforward dimension (default: 4 * d_model) */
1371
+ d_ff: OptionType<IntegerType>;
1372
+ /** Dropout rate */
1373
+ dropout: OptionType<FloatType>;
1374
+ /** Whether return is per-timestep or global */
1375
+ return_embedding: VariantType<{
1376
+ /** Single return value for entire sequence */
1377
+ global: NullType;
1378
+ /** Return-to-go at each timestep */
1379
+ per_timestep: NullType;
1380
+ }>;
1381
+ }>;
806
1382
  }>;
807
1383
  /** Output mode (determines loss function) */
808
1384
  output: VariantType<{
@@ -829,6 +1405,31 @@ export declare const LightningTypes: {
829
1405
  /** Optional class weights matrix (n_heads, n_classes) */
830
1406
  class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
831
1407
  }>;
1408
+ /**
1409
+ * Mixed output types per head.
1410
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
1411
+ * Binary heads: 1 logit → sigmoid → BCE loss
1412
+ * Multiclass heads: n_classes logits → softmax → CE loss
1413
+ * Action vectors use one-hot encoding for multiclass heads.
1414
+ */
1415
+ multi_head_mixed: StructType<{
1416
+ /** Array of head configurations */
1417
+ heads: ArrayType<StructType<{
1418
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
1419
+ head_type: VariantType<{
1420
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
1421
+ binary: NullType;
1422
+ /** Multi-class output: n_classes logits, softmax, CE loss */
1423
+ multiclass: StructType<{
1424
+ n_classes: IntegerType;
1425
+ }>;
1426
+ }>;
1427
+ /** Optional class weights for this head */
1428
+ class_weights: OptionType<ArrayType<FloatType>>;
1429
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
1430
+ conditional_on: OptionType<IntegerType>;
1431
+ }>>;
1432
+ }>;
832
1433
  }>;
833
1434
  /** Learning rate (default: 1e-3) */
834
1435
  learning_rate: OptionType<FloatType>;
@@ -910,6 +1511,57 @@ export declare const LightningTypes: {
910
1511
  /** If true, return probabilities. If false, return samples. */
911
1512
  return_probs: BooleanType;
912
1513
  }>;
1514
+ readonly ReturnEmbeddingType: VariantType<{
1515
+ /** Single return value for entire sequence */
1516
+ global: NullType;
1517
+ /** Return-to-go at each timestep */
1518
+ per_timestep: NullType;
1519
+ }>;
1520
+ readonly HeadConfigType: StructType<{
1521
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
1522
+ head_type: VariantType<{
1523
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
1524
+ binary: NullType;
1525
+ /** Multi-class output: n_classes logits, softmax, CE loss */
1526
+ multiclass: StructType<{
1527
+ n_classes: IntegerType;
1528
+ }>;
1529
+ }>;
1530
+ /** Optional class weights for this head */
1531
+ class_weights: OptionType<ArrayType<FloatType>>;
1532
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
1533
+ conditional_on: OptionType<IntegerType>;
1534
+ }>;
1535
+ readonly TrajectoryGenerateConfigType: StructType<{
1536
+ /** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
1537
+ temperature: FloatType;
1538
+ /** Whether to return probabilities or samples */
1539
+ return_probs: BooleanType;
1540
+ /** Optional constraint mask of shape (seq_len, action_dim), declared as floats; a FALSE entry (presumably 0.0 - confirm) disables that action */
1541
+ action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
1542
+ /** Optional temporal mask of shape (seq_len,), declared as floats; a FALSE entry (presumably 0.0 - confirm) marks an invalid timestep */
1543
+ temporal_mask: OptionType<ArrayType<FloatType>>;
1544
+ /** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
1545
+ head_configs: OptionType<ArrayType<StructType<{
1546
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
1547
+ head_type: VariantType<{
1548
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
1549
+ binary: NullType;
1550
+ /** Multi-class output: n_classes logits, softmax, CE loss */
1551
+ multiclass: StructType<{
1552
+ n_classes: IntegerType;
1553
+ }>;
1554
+ }>;
1555
+ /** Optional class weights for this head */
1556
+ class_weights: OptionType<ArrayType<FloatType>>;
1557
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
1558
+ conditional_on: OptionType<IntegerType>;
1559
+ }>>>;
1560
+ /** Optional action prefix: (seq_len, action_dim) - known actions for timesteps 0..start_timestep-1 */
1561
+ action_prefix: OptionType<ArrayType<ArrayType<FloatType>>>;
1562
+ /** Timestep to start generation from (0 = generate all, 5 = use prefix for 0-4, generate 5+) */
1563
+ start_timestep: OptionType<IntegerType>;
1564
+ }>;
913
1565
  };
914
1566
  /**
915
1567
  * Lightning platform functions namespace.
@@ -1015,6 +1667,36 @@ export declare const Lightning: {
1015
1667
  /** Optional condition dimension for conditional generation */
1016
1668
  condition_dim: OptionType<IntegerType>;
1017
1669
  }>;
1670
+ /**
1671
+ * Decision Transformer: return-conditioned sequence generation.
1672
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
1673
+ * Predicts actions conditioned on desired return and state history.
1674
+ */
1675
+ decision_transformer: StructType<{
1676
+ /** Sequence length (timesteps) */
1677
+ sequence_length: IntegerType;
1678
+ /** State dimension per timestep */
1679
+ state_dim: IntegerType;
1680
+ /** Action dimension per timestep */
1681
+ action_dim: IntegerType;
1682
+ /** Model dimension (transformer hidden size) */
1683
+ d_model: IntegerType;
1684
+ /** Number of attention heads */
1685
+ n_attention_heads: IntegerType;
1686
+ /** Number of transformer layers */
1687
+ n_layers: IntegerType;
1688
+ /** Feedforward dimension (default: 4 * d_model) */
1689
+ d_ff: OptionType<IntegerType>;
1690
+ /** Dropout rate */
1691
+ dropout: OptionType<FloatType>;
1692
+ /** Whether return is per-timestep or global */
1693
+ return_embedding: VariantType<{
1694
+ /** Single return value for entire sequence */
1695
+ global: NullType;
1696
+ /** Return-to-go at each timestep */
1697
+ per_timestep: NullType;
1698
+ }>;
1699
+ }>;
1018
1700
  }>;
1019
1701
  /** Output mode (determines loss function) */
1020
1702
  output: VariantType<{
@@ -1041,6 +1723,31 @@ export declare const Lightning: {
1041
1723
  /** Optional class weights matrix (n_heads, n_classes) */
1042
1724
  class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
1043
1725
  }>;
1726
+ /**
1727
+ * Mixed output types per head.
1728
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
1729
+ * Binary heads: 1 logit → sigmoid → BCE loss
1730
+ * Multiclass heads: n_classes logits → softmax → CE loss
1731
+ * Action vectors use one-hot encoding for multiclass heads.
1732
+ */
1733
+ multi_head_mixed: StructType<{
1734
+ /** Array of head configurations */
1735
+ heads: ArrayType<StructType<{
1736
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
1737
+ head_type: VariantType<{
1738
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
1739
+ binary: NullType;
1740
+ /** Multi-class output: n_classes logits, softmax, CE loss */
1741
+ multiclass: StructType<{
1742
+ n_classes: IntegerType;
1743
+ }>;
1744
+ }>;
1745
+ /** Optional class weights for this head */
1746
+ class_weights: OptionType<ArrayType<FloatType>>;
1747
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
1748
+ conditional_on: OptionType<IntegerType>;
1749
+ }>>;
1750
+ }>;
1044
1751
  }>;
1045
1752
  /** Learning rate (default: 1e-3) */
1046
1753
  learning_rate: OptionType<FloatType>;
@@ -1210,6 +1917,283 @@ export declare const Lightning: {
1210
1917
  /** If true, return probabilities. If false, return samples. */
1211
1918
  return_probs: BooleanType;
1212
1919
  }>], ArrayType<ArrayType<FloatType>>>;
1920
+ /**
1921
+ * Train a Decision Transformer with trajectory data.
1922
+ *
1923
+ * Trains a return-conditioned sequence generation model that learns
1924
+ * to predict actions given states and desired returns.
1925
+ *
1926
+ * @example
1927
+ * ```typescript
1928
+ * const result = Lightning.trainTrajectory(
1929
+ * returns, states, actions, masks,
1930
+ * {
1931
+ * architecture: variant("decision_transformer", {
1932
+ * sequence_length: 14n,
1933
+ * state_dim: 8n,
1934
+ * action_dim: 11n,
1935
+ * d_model: 64n,
1936
+ * n_attention_heads: 4n,
1937
+ * n_layers: 3n,
1938
+ * d_ff: variant("none", null),
1939
+ * dropout: variant("some", 0.1),
1940
+ * return_embedding: variant("global", null),
1941
+ * }),
1942
+ * output: variant("multi_head_mixed", { heads: [...] }),
1943
+ * ...
1944
+ * }
1945
+ * );
1946
+ * ```
1947
+ */
1948
+ readonly trainTrajectory: import("@elaraai/east").PlatformDefinition<[ArrayType<FloatType>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<ArrayType<FloatType>>, StructType<{
1949
+ /** Model architecture */
1950
+ architecture: VariantType<{
1951
+ /** Simple MLP: input → hidden → output */
1952
+ mlp: StructType<{
1953
+ /** Hidden layer sizes */
1954
+ hidden_layers: ArrayType<IntegerType>;
1955
+ }>;
1956
+ /** Autoencoder: input → encoder → latent → decoder → output */
1957
+ autoencoder: StructType<{
1958
+ /** Encoder hidden layer sizes */
1959
+ encoder_layers: ArrayType<IntegerType>;
1960
+ /** Latent dimension (bottleneck) */
1961
+ latent_dim: IntegerType;
1962
+ /** Decoder hidden layer sizes */
1963
+ decoder_layers: ArrayType<IntegerType>;
1964
+ }>;
1965
+ /** Conv1D: 1D convolutional autoencoder for temporal patterns */
1966
+ conv1d: StructType<{
1967
+ /** Number of channels (e.g., additive types) */
1968
+ n_channels: IntegerType;
1969
+ /** Sequence length (e.g., days) */
1970
+ sequence_length: IntegerType;
1971
+ /** Conv layer channel sizes */
1972
+ conv_channels: ArrayType<IntegerType>;
1973
+ /** Kernel size for convolutions (must be odd) */
1974
+ kernel_size: IntegerType;
1975
+ /** Latent dimension after flattening */
1976
+ latent_dim: IntegerType;
1977
+ /** Optional condition dimension for conditional generation */
1978
+ condition_dim: OptionType<IntegerType>;
1979
+ }>;
1980
+ /** Sequential: LSTM/GRU autoencoder for long-range dependencies */
1981
+ sequential: StructType<{
1982
+ /** Number of channels (e.g., additive types) */
1983
+ n_channels: IntegerType;
1984
+ /** Sequence length (e.g., days) */
1985
+ sequence_length: IntegerType;
1986
+ /** RNN hidden size */
1987
+ hidden_size: IntegerType;
1988
+ /** Number of RNN layers */
1989
+ n_layers: IntegerType;
1990
+ /** Cell type: lstm or gru */
1991
+ cell_type: VariantType<{
1992
+ lstm: NullType;
1993
+ gru: NullType;
1994
+ }>;
1995
+ /** Latent dimension (from final hidden state) */
1996
+ latent_dim: IntegerType;
1997
+ /** Bidirectional encoder (decoder is always unidirectional) */
1998
+ bidirectional: BooleanType;
1999
+ /** Optional condition dimension for conditional generation */
2000
+ condition_dim: OptionType<IntegerType>;
2001
+ }>;
2002
+ /** Transformer: attention-based autoencoder for complex patterns */
2003
+ transformer: StructType<{
2004
+ /** Number of channels (e.g., additive types) */
2005
+ n_channels: IntegerType;
2006
+ /** Sequence length (e.g., days) */
2007
+ sequence_length: IntegerType;
2008
+ /** Model dimension */
2009
+ d_model: IntegerType;
2010
+ /** Number of attention heads (must divide d_model evenly) */
2011
+ n_attention_heads: IntegerType;
2012
+ /** Number of transformer layers */
2013
+ n_layers: IntegerType;
2014
+ /** Feedforward dimension (default: 4 * d_model) */
2015
+ d_ff: OptionType<IntegerType>;
2016
+ /** Latent dimension (mean pooled output) */
2017
+ latent_dim: IntegerType;
2018
+ /** Optional condition dimension for conditional generation */
2019
+ condition_dim: OptionType<IntegerType>;
2020
+ }>;
2021
+ /**
2022
+ * Decision Transformer: return-conditioned sequence generation.
2023
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
2024
+ * Predicts actions conditioned on desired return and state history.
2025
+ */
2026
+ decision_transformer: StructType<{
2027
+ /** Sequence length (timesteps) */
2028
+ sequence_length: IntegerType;
2029
+ /** State dimension per timestep */
2030
+ state_dim: IntegerType;
2031
+ /** Action dimension per timestep */
2032
+ action_dim: IntegerType;
2033
+ /** Model dimension (transformer hidden size) */
2034
+ d_model: IntegerType;
2035
+ /** Number of attention heads */
2036
+ n_attention_heads: IntegerType;
2037
+ /** Number of transformer layers */
2038
+ n_layers: IntegerType;
2039
+ /** Feedforward dimension (default: 4 * d_model) */
2040
+ d_ff: OptionType<IntegerType>;
2041
+ /** Dropout rate */
2042
+ dropout: OptionType<FloatType>;
2043
+ /** Whether return is per-timestep or global */
2044
+ return_embedding: VariantType<{
2045
+ /** Single return value for entire sequence */
2046
+ global: NullType;
2047
+ /** Return-to-go at each timestep */
2048
+ per_timestep: NullType;
2049
+ }>;
2050
+ }>;
2051
+ }>;
2052
+ /** Output mode (determines loss function) */
2053
+ output: VariantType<{
2054
+ /** Regression: MSE loss, no activation */
2055
+ regression: NullType;
2056
+ /** Binary: BCE loss, sigmoid activation */
2057
+ binary: StructType<{
2058
+ /** Optional per-position pos_weights for class imbalance [output_dim] */
2059
+ pos_weight: OptionType<ArrayType<FloatType>>;
2060
+ }>;
2061
+ /** Multiclass: CrossEntropy loss, softmax activation */
2062
+ multiclass: StructType<{
2063
+ /** Number of classes */
2064
+ n_classes: IntegerType;
2065
+ /** Optional per-class weights */
2066
+ class_weights: OptionType<ArrayType<FloatType>>;
2067
+ }>;
2068
+ /** Multi-head categorical: N independent CrossEntropy heads */
2069
+ multi_head: StructType<{
2070
+ /** Number of heads (e.g., 84 time slots) */
2071
+ n_heads: IntegerType;
2072
+ /** Classes per head (e.g., 4 bins) */
2073
+ n_classes_per_head: IntegerType;
2074
+ /** Optional class weights matrix (n_heads, n_classes) */
2075
+ class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
2076
+ }>;
2077
+ /**
2078
+ * Mixed output types per head.
2079
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
2080
+ * Binary heads: 1 logit → sigmoid → BCE loss
2081
+ * Multiclass heads: n_classes logits → softmax → CE loss
2082
+ * Action vectors use one-hot encoding for multiclass heads.
2083
+ */
2084
+ multi_head_mixed: StructType<{
2085
+ /** Array of head configurations */
2086
+ heads: ArrayType<StructType<{
2087
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
2088
+ head_type: VariantType<{
2089
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
2090
+ binary: NullType;
2091
+ /** Multi-class output: n_classes logits, softmax, CE loss */
2092
+ multiclass: StructType<{
2093
+ n_classes: IntegerType;
2094
+ }>;
2095
+ }>;
2096
+ /** Optional class weights for this head */
2097
+ class_weights: OptionType<ArrayType<FloatType>>;
2098
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
2099
+ conditional_on: OptionType<IntegerType>;
2100
+ }>>;
2101
+ }>;
2102
+ }>;
2103
+ /** Learning rate (default: 1e-3) */
2104
+ learning_rate: OptionType<FloatType>;
2105
+ /** Maximum epochs (default: 100) */
2106
+ max_epochs: OptionType<IntegerType>;
2107
+ /** Early stopping patience (default: 10) */
2108
+ patience: OptionType<IntegerType>;
2109
+ /** Batch size (default: 32) */
2110
+ batch_size: OptionType<IntegerType>;
2111
+ /** Dropout rate (default: 0.1) */
2112
+ dropout: OptionType<FloatType>;
2113
+ /** Gradient clipping value (default: 1.0) */
2114
+ gradient_clip: OptionType<FloatType>;
2115
+ /** L2 regularization weight decay (default: 0) */
2116
+ weight_decay: OptionType<FloatType>;
2117
+ /** Random seed for reproducibility */
2118
+ random_state: OptionType<IntegerType>;
2119
+ /** Optional callback called each epoch */
2120
+ epoch_callback: OptionType<FunctionType<[IntegerType, FloatType, FloatType], NullType>>;
2121
+ }>], StructType<{
2122
+ /** Trained model blob */
2123
+ model: VariantType<{
2124
+ lightning: StructType<{
2125
+ /** Serialized model data (state_dict + hparams) */
2126
+ data: BlobType;
2127
+ /** Input dimension */
2128
+ n_features: IntegerType;
2129
+ /** Output dimension */
2130
+ output_dim: IntegerType;
2131
+ /** Architecture type */
2132
+ architecture_type: StringType;
2133
+ /** Output type */
2134
+ output_type: StringType;
2135
+ /** Latent dimension (autoencoder only) */
2136
+ latent_dim: OptionType<IntegerType>;
2137
+ }>;
2138
+ }>;
2139
+ /** Final training loss */
2140
+ train_loss: FloatType;
2141
+ /** Final validation loss */
2142
+ val_loss: FloatType;
2143
+ /** Best epoch (for early stopping) */
2144
+ best_epoch: IntegerType;
2145
+ }>>;
2146
+ /**
2147
+ * Generate action sequences from a Decision Transformer.
2148
+ *
2149
+ * Autoregressively generates actions conditioned on target returns
2150
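+ * and state sequences. Note the argument order: model, then the 3-D array (presumably states), then the 1-D array (presumably target returns), then config - unlike trainTrajectory, where the 1-D returns array comes first.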
2151
+ */
2152
+ readonly generateTrajectory: import("@elaraai/east").PlatformDefinition<[VariantType<{
2153
+ lightning: StructType<{
2154
+ /** Serialized model data (state_dict + hparams) */
2155
+ data: BlobType;
2156
+ /** Input dimension */
2157
+ n_features: IntegerType;
2158
+ /** Output dimension */
2159
+ output_dim: IntegerType;
2160
+ /** Architecture type */
2161
+ architecture_type: StringType;
2162
+ /** Output type */
2163
+ output_type: StringType;
2164
+ /** Latent dimension (autoencoder only) */
2165
+ latent_dim: OptionType<IntegerType>;
2166
+ }>;
2167
+ }>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<FloatType>, StructType<{
2168
+ /** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
2169
+ temperature: FloatType;
2170
+ /** If true, return probabilities; if false, return samples */
2171
+ return_probs: BooleanType;
2172
+ /** Optional constraint mask: (seq_len, action_dim) - a FALSE entry disables that action (entries are floats, so FALSE is presumably 0.0 - confirm) */
2173
+ action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
2174
+ /** Optional temporal mask: (seq_len,) - a FALSE entry marks an invalid timestep (entries are floats, so FALSE is presumably 0.0 - confirm) */
2175
+ temporal_mask: OptionType<ArrayType<FloatType>>;
2176
+ /** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
2177
+ head_configs: OptionType<ArrayType<StructType<{
2178
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
2179
+ head_type: VariantType<{
2180
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
2181
+ binary: NullType;
2182
+ /** Multi-class output: n_classes logits, softmax, CE loss */
2183
+ multiclass: StructType<{
2184
+ n_classes: IntegerType;
2185
+ }>;
2186
+ }>;
2187
+ /** Optional class weights for this head */
2188
+ class_weights: OptionType<ArrayType<FloatType>>;
2189
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
2190
+ conditional_on: OptionType<IntegerType>;
2191
+ }>>>;
2192
+ /** Optional action prefix: (seq_len, action_dim) - known actions for timesteps 0..start_timestep-1 */
2193
+ action_prefix: OptionType<ArrayType<ArrayType<FloatType>>>;
2194
+ /** Timestep to start generation from (0 = generate all; e.g. 5 = take timesteps 0-4 from action_prefix and generate from 5 onward) */
2195
+ start_timestep: OptionType<IntegerType>;
2196
+ }>], ArrayType<ArrayType<ArrayType<FloatType>>>>;
1213
2197
  /**
1214
2198
  * Type definitions for Lightning functions.
1215
2199
  */
@@ -1238,6 +2222,31 @@ export declare const Lightning: {
1238
2222
  /** Optional class weights matrix (n_heads, n_classes) */
1239
2223
  class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
1240
2224
  }>;
2225
+ /**
2226
+ * Mixed output types per head.
2227
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
2228
+ * Binary heads: 1 logit → sigmoid → BCE loss
2229
+ * Multiclass heads: n_classes logits → softmax → CE loss
2230
+ * Action vectors use one-hot encoding for multiclass heads.
2231
+ */
2232
+ multi_head_mixed: StructType<{
2233
+ /** Array of head configurations */
2234
+ heads: ArrayType<StructType<{
2235
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
2236
+ head_type: VariantType<{
2237
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
2238
+ binary: NullType;
2239
+ /** Multi-class output: n_classes logits, softmax, CE loss */
2240
+ multiclass: StructType<{
2241
+ n_classes: IntegerType;
2242
+ }>;
2243
+ }>;
2244
+ /** Optional class weights for this head */
2245
+ class_weights: OptionType<ArrayType<FloatType>>;
2246
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
2247
+ conditional_on: OptionType<IntegerType>;
2248
+ }>>;
2249
+ }>;
1241
2250
  }>;
1242
2251
  readonly ArchitectureType: VariantType<{
1243
2252
  /** Simple MLP: input → hidden → output */
@@ -1310,6 +2319,36 @@ export declare const Lightning: {
1310
2319
  /** Optional condition dimension for conditional generation */
1311
2320
  condition_dim: OptionType<IntegerType>;
1312
2321
  }>;
2322
+ /**
2323
+ * Decision Transformer: return-conditioned sequence generation.
2324
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
2325
+ * Predicts actions conditioned on desired return and state history.
2326
+ */
2327
+ decision_transformer: StructType<{
2328
+ /** Sequence length (timesteps) */
2329
+ sequence_length: IntegerType;
2330
+ /** State dimension per timestep */
2331
+ state_dim: IntegerType;
2332
+ /** Action dimension per timestep */
2333
+ action_dim: IntegerType;
2334
+ /** Model dimension (transformer hidden size) */
2335
+ d_model: IntegerType;
2336
+ /** Number of attention heads */
2337
+ n_attention_heads: IntegerType;
2338
+ /** Number of transformer layers */
2339
+ n_layers: IntegerType;
2340
+ /** Feedforward dimension (default: 4 * d_model) */
2341
+ d_ff: OptionType<IntegerType>;
2342
+ /** Dropout rate */
2343
+ dropout: OptionType<FloatType>;
2344
+ /** Whether return is per-timestep or global */
2345
+ return_embedding: VariantType<{
2346
+ /** Single return value for entire sequence */
2347
+ global: NullType;
2348
+ /** Return-to-go at each timestep */
2349
+ per_timestep: NullType;
2350
+ }>;
2351
+ }>;
1313
2352
  }>;
1314
2353
  readonly CellType: VariantType<{
1315
2354
  lstm: NullType;
@@ -1389,6 +2428,36 @@ export declare const Lightning: {
1389
2428
  /** Optional condition dimension for conditional generation */
1390
2429
  condition_dim: OptionType<IntegerType>;
1391
2430
  }>;
2431
+ /**
2432
+ * Decision Transformer: return-conditioned sequence generation.
2433
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
2434
+ * Predicts actions conditioned on desired return and state history.
2435
+ */
2436
+ decision_transformer: StructType<{
2437
+ /** Sequence length (timesteps) */
2438
+ sequence_length: IntegerType;
2439
+ /** State dimension per timestep */
2440
+ state_dim: IntegerType;
2441
+ /** Action dimension per timestep */
2442
+ action_dim: IntegerType;
2443
+ /** Model dimension (transformer hidden size) */
2444
+ d_model: IntegerType;
2445
+ /** Number of attention heads */
2446
+ n_attention_heads: IntegerType;
2447
+ /** Number of transformer layers */
2448
+ n_layers: IntegerType;
2449
+ /** Feedforward dimension (default: 4 * d_model) */
2450
+ d_ff: OptionType<IntegerType>;
2451
+ /** Dropout rate */
2452
+ dropout: OptionType<FloatType>;
2453
+ /** Whether return is per-timestep or global */
2454
+ return_embedding: VariantType<{
2455
+ /** Single return value for entire sequence */
2456
+ global: NullType;
2457
+ /** Return-to-go at each timestep */
2458
+ per_timestep: NullType;
2459
+ }>;
2460
+ }>;
1392
2461
  }>;
1393
2462
  /** Output mode (determines loss function) */
1394
2463
  output: VariantType<{
@@ -1415,6 +2484,31 @@ export declare const Lightning: {
1415
2484
  /** Optional class weights matrix (n_heads, n_classes) */
1416
2485
  class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
1417
2486
  }>;
2487
+ /**
2488
+ * Mixed output types per head.
2489
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
2490
+ * Binary heads: 1 logit → sigmoid → BCE loss
2491
+ * Multiclass heads: n_classes logits → softmax → CE loss
2492
+ * Action vectors use one-hot encoding for multiclass heads.
2493
+ */
2494
+ multi_head_mixed: StructType<{
2495
+ /** Array of head configurations */
2496
+ heads: ArrayType<StructType<{
2497
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
2498
+ head_type: VariantType<{
2499
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
2500
+ binary: NullType;
2501
+ /** Multi-class output: n_classes logits, softmax, CE loss */
2502
+ multiclass: StructType<{
2503
+ n_classes: IntegerType;
2504
+ }>;
2505
+ }>;
2506
+ /** Optional class weights for this head */
2507
+ class_weights: OptionType<ArrayType<FloatType>>;
2508
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
2509
+ conditional_on: OptionType<IntegerType>;
2510
+ }>>;
2511
+ }>;
1418
2512
  }>;
1419
2513
  /** Learning rate (default: 1e-3) */
1420
2514
  learning_rate: OptionType<FloatType>;
@@ -1496,6 +2590,57 @@ export declare const Lightning: {
1496
2590
  /** If true, return probabilities. If false, return samples. */
1497
2591
  return_probs: BooleanType;
1498
2592
  }>;
2593
+ readonly ReturnEmbeddingType: VariantType<{
2594
+ /** Single return value for entire sequence */
2595
+ global: NullType;
2596
+ /** Return-to-go at each timestep */
2597
+ per_timestep: NullType;
2598
+ }>;
2599
+ readonly HeadConfigType: StructType<{
2600
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
2601
+ head_type: VariantType<{
2602
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
2603
+ binary: NullType;
2604
+ /** Multi-class output: n_classes logits, softmax, CE loss */
2605
+ multiclass: StructType<{
2606
+ n_classes: IntegerType;
2607
+ }>;
2608
+ }>;
2609
+ /** Optional class weights for this head */
2610
+ class_weights: OptionType<ArrayType<FloatType>>;
2611
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
2612
+ conditional_on: OptionType<IntegerType>;
2613
+ }>;
2614
+ readonly TrajectoryGenerateConfigType: StructType<{
2615
+ /** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
2616
+ temperature: FloatType;
2617
+ /** If true, return probabilities; if false, return samples */
2618
+ return_probs: BooleanType;
2619
+ /** Optional constraint mask: (seq_len, action_dim) - a FALSE entry disables that action (entries are floats, so FALSE is presumably 0.0 - confirm) */
2620
+ action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
2621
+ /** Optional temporal mask: (seq_len,) - a FALSE entry marks an invalid timestep (entries are floats, so FALSE is presumably 0.0 - confirm) */
2622
+ temporal_mask: OptionType<ArrayType<FloatType>>;
2623
+ /** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
2624
+ head_configs: OptionType<ArrayType<StructType<{
2625
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
2626
+ head_type: VariantType<{
2627
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
2628
+ binary: NullType;
2629
+ /** Multi-class output: n_classes logits, softmax, CE loss */
2630
+ multiclass: StructType<{
2631
+ n_classes: IntegerType;
2632
+ }>;
2633
+ }>;
2634
+ /** Optional class weights for this head */
2635
+ class_weights: OptionType<ArrayType<FloatType>>;
2636
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
2637
+ conditional_on: OptionType<IntegerType>;
2638
+ }>>>;
2639
+ /** Optional action prefix: (seq_len, action_dim) - known actions for timesteps 0..start_timestep-1 */
2640
+ action_prefix: OptionType<ArrayType<ArrayType<FloatType>>>;
2641
+ /** Timestep to start generation from (0 = generate all; e.g. 5 = take timesteps 0-4 from action_prefix and generate from 5 onward) */
2642
+ start_timestep: OptionType<IntegerType>;
2643
+ }>;
1499
2644
  };
1500
2645
  };
1501
2646
  //# sourceMappingURL=lightning.d.ts.map