@elaraai/east-py-datascience 0.0.2-beta.31 → 0.0.2-beta.33

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -13,6 +13,33 @@
13
13
  */
14
14
  import { StructType, VariantType, OptionType, IntegerType, FloatType, BlobType, ArrayType, NullType, BooleanType, FunctionType, StringType } from "@elaraai/east";
15
15
  export { VectorType, MatrixType } from "../types.js";
16
+ /**
17
+ * Return embedding mode for Decision Transformer.
18
+ */
19
+ export declare const ReturnEmbeddingType: VariantType<{
20
+ /** Single return value for entire sequence */
21
+ global: NullType;
22
+ /** Return-to-go at each timestep */
23
+ per_timestep: NullType;
24
+ }>;
25
+ /**
26
+ * Per-head output configuration for multi_head_mixed.
27
+ */
28
+ export declare const HeadConfigType: StructType<{
29
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
30
+ head_type: VariantType<{
31
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
32
+ binary: NullType;
33
+ /** Multi-class output: n_classes logits, softmax, CE loss */
34
+ multiclass: StructType<{
35
+ n_classes: IntegerType;
36
+ }>;
37
+ }>;
38
+ /** Optional class weights for this head */
39
+ class_weights: OptionType<ArrayType<FloatType>>;
40
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
41
+ conditional_on: OptionType<IntegerType>;
42
+ }>;
16
43
  /**
17
44
  * Lightning output mode - determines loss function and output activation.
18
45
  */
@@ -40,6 +67,31 @@ export declare const LightningOutputType: VariantType<{
40
67
  /** Optional class weights matrix (n_heads, n_classes) */
41
68
  class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
42
69
  }>;
70
+ /**
71
+ * Mixed output types per head.
72
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
73
+ * Binary heads: 1 logit → sigmoid → BCE loss
74
+ * Multiclass heads: n_classes logits → softmax → CE loss
75
+ * Action vectors use one-hot encoding for multiclass heads.
76
+ */
77
+ multi_head_mixed: StructType<{
78
+ /** Array of head configurations */
79
+ heads: ArrayType<StructType<{
80
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
81
+ head_type: VariantType<{
82
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
83
+ binary: NullType;
84
+ /** Multi-class output: n_classes logits, softmax, CE loss */
85
+ multiclass: StructType<{
86
+ n_classes: IntegerType;
87
+ }>;
88
+ }>;
89
+ /** Optional class weights for this head */
90
+ class_weights: OptionType<ArrayType<FloatType>>;
91
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
92
+ conditional_on: OptionType<IntegerType>;
93
+ }>>;
94
+ }>;
43
95
  }>;
44
96
  /**
45
97
  * Cell type for sequential architectures.
@@ -122,6 +174,36 @@ export declare const LightningArchitectureType: VariantType<{
122
174
  /** Optional condition dimension for conditional generation */
123
175
  condition_dim: OptionType<IntegerType>;
124
176
  }>;
177
+ /**
178
+ * Decision Transformer: return-conditioned sequence generation.
179
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
180
+ * Predicts actions conditioned on desired return and state history.
181
+ */
182
+ decision_transformer: StructType<{
183
+ /** Sequence length (timesteps) */
184
+ sequence_length: IntegerType;
185
+ /** State dimension per timestep */
186
+ state_dim: IntegerType;
187
+ /** Action dimension per timestep */
188
+ action_dim: IntegerType;
189
+ /** Model dimension (transformer hidden size) */
190
+ d_model: IntegerType;
191
+ /** Number of attention heads */
192
+ n_attention_heads: IntegerType;
193
+ /** Number of transformer layers */
194
+ n_layers: IntegerType;
195
+ /** Feedforward dimension (default: 4 * d_model) */
196
+ d_ff: OptionType<IntegerType>;
197
+ /** Dropout rate */
198
+ dropout: OptionType<FloatType>;
199
+ /** Whether return is per-timestep or global */
200
+ return_embedding: VariantType<{
201
+ /** Single return value for entire sequence */
202
+ global: NullType;
203
+ /** Return-to-go at each timestep */
204
+ per_timestep: NullType;
205
+ }>;
206
+ }>;
125
207
  }>;
126
208
  /**
127
209
  * Epoch callback function type: (epoch, train_loss, val_loss) -> void
@@ -203,6 +285,36 @@ export declare const LightningConfigType: StructType<{
203
285
  /** Optional condition dimension for conditional generation */
204
286
  condition_dim: OptionType<IntegerType>;
205
287
  }>;
288
+ /**
289
+ * Decision Transformer: return-conditioned sequence generation.
290
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
291
+ * Predicts actions conditioned on desired return and state history.
292
+ */
293
+ decision_transformer: StructType<{
294
+ /** Sequence length (timesteps) */
295
+ sequence_length: IntegerType;
296
+ /** State dimension per timestep */
297
+ state_dim: IntegerType;
298
+ /** Action dimension per timestep */
299
+ action_dim: IntegerType;
300
+ /** Model dimension (transformer hidden size) */
301
+ d_model: IntegerType;
302
+ /** Number of attention heads */
303
+ n_attention_heads: IntegerType;
304
+ /** Number of transformer layers */
305
+ n_layers: IntegerType;
306
+ /** Feedforward dimension (default: 4 * d_model) */
307
+ d_ff: OptionType<IntegerType>;
308
+ /** Dropout rate */
309
+ dropout: OptionType<FloatType>;
310
+ /** Whether return is per-timestep or global */
311
+ return_embedding: VariantType<{
312
+ /** Single return value for entire sequence */
313
+ global: NullType;
314
+ /** Return-to-go at each timestep */
315
+ per_timestep: NullType;
316
+ }>;
317
+ }>;
206
318
  }>;
207
319
  /** Output mode (determines loss function) */
208
320
  output: VariantType<{
@@ -229,6 +341,31 @@ export declare const LightningConfigType: StructType<{
229
341
  /** Optional class weights matrix (n_heads, n_classes) */
230
342
  class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
231
343
  }>;
344
+ /**
345
+ * Mixed output types per head.
346
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
347
+ * Binary heads: 1 logit → sigmoid → BCE loss
348
+ * Multiclass heads: n_classes logits → softmax → CE loss
349
+ * Action vectors use one-hot encoding for multiclass heads.
350
+ */
351
+ multi_head_mixed: StructType<{
352
+ /** Array of head configurations */
353
+ heads: ArrayType<StructType<{
354
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
355
+ head_type: VariantType<{
356
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
357
+ binary: NullType;
358
+ /** Multi-class output: n_classes logits, softmax, CE loss */
359
+ multiclass: StructType<{
360
+ n_classes: IntegerType;
361
+ }>;
362
+ }>;
363
+ /** Optional class weights for this head */
364
+ class_weights: OptionType<ArrayType<FloatType>>;
365
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
366
+ conditional_on: OptionType<IntegerType>;
367
+ }>>;
368
+ }>;
232
369
  }>;
233
370
  /** Learning rate (default: 1e-3) */
234
371
  learning_rate: OptionType<FloatType>;
@@ -401,6 +538,36 @@ export declare const lightning_train: import("@elaraai/east").PlatformDefinition
401
538
  /** Optional condition dimension for conditional generation */
402
539
  condition_dim: OptionType<IntegerType>;
403
540
  }>;
541
+ /**
542
+ * Decision Transformer: return-conditioned sequence generation.
543
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
544
+ * Predicts actions conditioned on desired return and state history.
545
+ */
546
+ decision_transformer: StructType<{
547
+ /** Sequence length (timesteps) */
548
+ sequence_length: IntegerType;
549
+ /** State dimension per timestep */
550
+ state_dim: IntegerType;
551
+ /** Action dimension per timestep */
552
+ action_dim: IntegerType;
553
+ /** Model dimension (transformer hidden size) */
554
+ d_model: IntegerType;
555
+ /** Number of attention heads */
556
+ n_attention_heads: IntegerType;
557
+ /** Number of transformer layers */
558
+ n_layers: IntegerType;
559
+ /** Feedforward dimension (default: 4 * d_model) */
560
+ d_ff: OptionType<IntegerType>;
561
+ /** Dropout rate */
562
+ dropout: OptionType<FloatType>;
563
+ /** Whether return is per-timestep or global */
564
+ return_embedding: VariantType<{
565
+ /** Single return value for entire sequence */
566
+ global: NullType;
567
+ /** Return-to-go at each timestep */
568
+ per_timestep: NullType;
569
+ }>;
570
+ }>;
404
571
  }>;
405
572
  /** Output mode (determines loss function) */
406
573
  output: VariantType<{
@@ -427,6 +594,31 @@ export declare const lightning_train: import("@elaraai/east").PlatformDefinition
427
594
  /** Optional class weights matrix (n_heads, n_classes) */
428
595
  class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
429
596
  }>;
597
+ /**
598
+ * Mixed output types per head.
599
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
600
+ * Binary heads: 1 logit → sigmoid → BCE loss
601
+ * Multiclass heads: n_classes logits → softmax → CE loss
602
+ * Action vectors use one-hot encoding for multiclass heads.
603
+ */
604
+ multi_head_mixed: StructType<{
605
+ /** Array of head configurations */
606
+ heads: ArrayType<StructType<{
607
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
608
+ head_type: VariantType<{
609
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
610
+ binary: NullType;
611
+ /** Multi-class output: n_classes logits, softmax, CE loss */
612
+ multiclass: StructType<{
613
+ n_classes: IntegerType;
614
+ }>;
615
+ }>;
616
+ /** Optional class weights for this head */
617
+ class_weights: OptionType<ArrayType<FloatType>>;
618
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
619
+ conditional_on: OptionType<IntegerType>;
620
+ }>>;
621
+ }>;
430
622
  }>;
431
623
  /** Learning rate (default: 1e-3) */
432
624
  learning_rate: OptionType<FloatType>;
@@ -577,35 +769,97 @@ export declare const lightning_decode_conditional: import("@elaraai/east").Platf
577
769
  }>;
578
770
  }>, ArrayType<ArrayType<FloatType>>, ArrayType<ArrayType<FloatType>>], ArrayType<ArrayType<FloatType>>>;
579
771
  /**
580
- * Lightning types namespace.
772
+ * Configuration for autoregressive sequence generation.
581
773
  */
582
- export declare const LightningTypes: {
583
- readonly OutputType: VariantType<{
584
- /** Regression: MSE loss, no activation */
585
- regression: NullType;
586
- /** Binary: BCE loss, sigmoid activation */
587
- binary: StructType<{
588
- /** Optional per-position pos_weights for class imbalance [output_dim] */
589
- pos_weight: OptionType<ArrayType<FloatType>>;
590
- }>;
591
- /** Multiclass: CrossEntropy loss, softmax activation */
592
- multiclass: StructType<{
593
- /** Number of classes */
594
- n_classes: IntegerType;
595
- /** Optional per-class weights */
596
- class_weights: OptionType<ArrayType<FloatType>>;
597
- }>;
598
- /** Multi-head categorical: N independent CrossEntropy heads */
599
- multi_head: StructType<{
600
- /** Number of heads (e.g., 84 time slots) */
601
- n_heads: IntegerType;
602
- /** Classes per head (e.g., 4 bins) */
603
- n_classes_per_head: IntegerType;
604
- /** Optional class weights matrix (n_heads, n_classes) */
605
- class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
606
- }>;
774
+ export declare const LightningGenerateConfigType: StructType<{
775
+ /** Number of steps to generate */
776
+ n_steps: IntegerType;
777
+ /** Sampling temperature: 0.0 = argmax, > 0 = scaled sampling */
778
+ temperature: FloatType;
779
+ /** If true, return probabilities. If false, return samples. */
780
+ return_probs: BooleanType;
781
+ }>;
782
+ /**
783
+ * Generate sequence autoregressively from a sequential model.
784
+ *
785
+ * Shapes:
786
+ * - prefix: (n_prefix_steps, n_channels) - partial history to continue from, can be empty []
787
+ * - condition: (1, condition_dim) - conditioning features, or none
788
+ * - returns: (n_steps, n_channels) - generated timesteps only (not including prefix)
789
+ *
790
+ * @param model - Trained sequential model blob
791
+ * @param prefix - Partial history to continue from
792
+ * @param condition - Optional conditioning features
793
+ * @param config - Generation configuration
794
+ * @returns Generated sequence matrix
795
+ */
796
+ export declare const lightning_generate_sequence: import("@elaraai/east").PlatformDefinition<[VariantType<{
797
+ lightning: StructType<{
798
+ /** Serialized model data (state_dict + hparams) */
799
+ data: BlobType;
800
+ /** Input dimension */
801
+ n_features: IntegerType;
802
+ /** Output dimension */
803
+ output_dim: IntegerType;
804
+ /** Architecture type */
805
+ architecture_type: StringType;
806
+ /** Output type */
807
+ output_type: StringType;
808
+ /** Latent dimension (autoencoder only) */
809
+ latent_dim: OptionType<IntegerType>;
607
810
  }>;
608
- readonly ArchitectureType: VariantType<{
811
+ }>, ArrayType<ArrayType<FloatType>>, OptionType<ArrayType<ArrayType<FloatType>>>, StructType<{
812
+ /** Number of steps to generate */
813
+ n_steps: IntegerType;
814
+ /** Sampling temperature: 0.0 = argmax, > 0 = scaled sampling */
815
+ temperature: FloatType;
816
+ /** If true, return probabilities. If false, return samples. */
817
+ return_probs: BooleanType;
818
+ }>], ArrayType<ArrayType<FloatType>>>;
819
+ /**
820
+ * Configuration for Decision Transformer trajectory generation.
821
+ */
822
+ export declare const TrajectoryGenerateConfigType: StructType<{
823
+ /** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
824
+ temperature: FloatType;
825
+ /** Whether to return probabilities or samples */
826
+ return_probs: BooleanType;
827
+ /** Optional constraint mask: (seq_len, action_dim) - FALSE disables action (typed as a float matrix; presumably 0.0 = FALSE - confirm) */
828
+ action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
829
+ /** Optional temporal mask: (seq_len,) - FALSE marks invalid timesteps (typed as a float vector; presumably 0.0 = FALSE - confirm) */
830
+ temporal_mask: OptionType<ArrayType<FloatType>>;
831
+ /** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
832
+ head_configs: OptionType<ArrayType<StructType<{
833
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
834
+ head_type: VariantType<{
835
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
836
+ binary: NullType;
837
+ /** Multi-class output: n_classes logits, softmax, CE loss */
838
+ multiclass: StructType<{
839
+ n_classes: IntegerType;
840
+ }>;
841
+ }>;
842
+ /** Optional class weights for this head */
843
+ class_weights: OptionType<ArrayType<FloatType>>;
844
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
845
+ conditional_on: OptionType<IntegerType>;
846
+ }>>>;
847
+ }>;
848
+ /**
849
+ * Train with trajectory data for return-conditioned sequence generation.
850
+ *
851
+ * Use with decision_transformer architecture.
852
+ *
853
+ * @param returns - Return per sample (n_samples,) - actual outcome achieved
854
+ * @param states - State matrices: n_samples × (seq_len, state_dim)
855
+ * @param actions - Action matrices: n_samples × (seq_len, action_dim)
856
+ * @param masks - Temporal masks: n_samples × (seq_len,) - valid timesteps
857
+ * @param config - Training configuration with decision_transformer architecture
858
+ * @returns Training result with model blob and metrics
859
+ */
860
+ export declare const lightning_train_trajectory: import("@elaraai/east").PlatformDefinition<[ArrayType<FloatType>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<ArrayType<FloatType>>, StructType<{
861
+ /** Model architecture */
862
+ architecture: VariantType<{
609
863
  /** Simple MLP: input → hidden → output */
610
864
  mlp: StructType<{
611
865
  /** Hidden layer sizes */
@@ -676,111 +930,498 @@ export declare const LightningTypes: {
676
930
  /** Optional condition dimension for conditional generation */
677
931
  condition_dim: OptionType<IntegerType>;
678
932
  }>;
679
- }>;
680
- readonly CellType: VariantType<{
681
- lstm: NullType;
682
- gru: NullType;
683
- }>;
684
- readonly EpochCallbackType: FunctionType<[IntegerType, FloatType, FloatType], NullType>;
685
- readonly ConfigType: StructType<{
686
- /** Model architecture */
687
- architecture: VariantType<{
688
- /** Simple MLP: input → hidden → output */
689
- mlp: StructType<{
690
- /** Hidden layer sizes */
691
- hidden_layers: ArrayType<IntegerType>;
692
- }>;
693
- /** Autoencoder: input → encoder → latent → decoder → output */
694
- autoencoder: StructType<{
695
- /** Encoder hidden layer sizes */
696
- encoder_layers: ArrayType<IntegerType>;
697
- /** Latent dimension (bottleneck) */
698
- latent_dim: IntegerType;
699
- /** Decoder hidden layer sizes */
700
- decoder_layers: ArrayType<IntegerType>;
701
- }>;
702
- /** Conv1D: 1D convolutional autoencoder for temporal patterns */
703
- conv1d: StructType<{
704
- /** Number of channels (e.g., additive types) */
705
- n_channels: IntegerType;
706
- /** Sequence length (e.g., days) */
707
- sequence_length: IntegerType;
708
- /** Conv layer channel sizes */
709
- conv_channels: ArrayType<IntegerType>;
710
- /** Kernel size for convolutions (must be odd) */
711
- kernel_size: IntegerType;
712
- /** Latent dimension after flattening */
713
- latent_dim: IntegerType;
714
- /** Optional condition dimension for conditional generation */
715
- condition_dim: OptionType<IntegerType>;
716
- }>;
717
- /** Sequential: LSTM/GRU autoencoder for long-range dependencies */
718
- sequential: StructType<{
719
- /** Number of channels (e.g., additive types) */
720
- n_channels: IntegerType;
721
- /** Sequence length (e.g., days) */
722
- sequence_length: IntegerType;
723
- /** RNN hidden size */
724
- hidden_size: IntegerType;
725
- /** Number of RNN layers */
726
- n_layers: IntegerType;
727
- /** Cell type: lstm or gru */
728
- cell_type: VariantType<{
729
- lstm: NullType;
730
- gru: NullType;
731
- }>;
732
- /** Latent dimension (from final hidden state) */
733
- latent_dim: IntegerType;
734
- /** Bidirectional encoder (decoder is always unidirectional) */
735
- bidirectional: BooleanType;
736
- /** Optional condition dimension for conditional generation */
737
- condition_dim: OptionType<IntegerType>;
738
- }>;
739
- /** Transformer: attention-based autoencoder for complex patterns */
740
- transformer: StructType<{
741
- /** Number of channels (e.g., additive types) */
742
- n_channels: IntegerType;
743
- /** Sequence length (e.g., days) */
744
- sequence_length: IntegerType;
745
- /** Model dimension */
746
- d_model: IntegerType;
747
- /** Number of attention heads (must divide d_model evenly) */
748
- n_attention_heads: IntegerType;
749
- /** Number of transformer layers */
750
- n_layers: IntegerType;
751
- /** Feedforward dimension (default: 4 * d_model) */
752
- d_ff: OptionType<IntegerType>;
753
- /** Latent dimension (mean pooled output) */
754
- latent_dim: IntegerType;
755
- /** Optional condition dimension for conditional generation */
756
- condition_dim: OptionType<IntegerType>;
933
+ /**
934
+ * Decision Transformer: return-conditioned sequence generation.
935
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
936
+ * Predicts actions conditioned on desired return and state history.
937
+ */
938
+ decision_transformer: StructType<{
939
+ /** Sequence length (timesteps) */
940
+ sequence_length: IntegerType;
941
+ /** State dimension per timestep */
942
+ state_dim: IntegerType;
943
+ /** Action dimension per timestep */
944
+ action_dim: IntegerType;
945
+ /** Model dimension (transformer hidden size) */
946
+ d_model: IntegerType;
947
+ /** Number of attention heads */
948
+ n_attention_heads: IntegerType;
949
+ /** Number of transformer layers */
950
+ n_layers: IntegerType;
951
+ /** Feedforward dimension (default: 4 * d_model) */
952
+ d_ff: OptionType<IntegerType>;
953
+ /** Dropout rate */
954
+ dropout: OptionType<FloatType>;
955
+ /** Whether return is per-timestep or global */
956
+ return_embedding: VariantType<{
957
+ /** Single return value for entire sequence */
958
+ global: NullType;
959
+ /** Return-to-go at each timestep */
960
+ per_timestep: NullType;
757
961
  }>;
758
962
  }>;
759
- /** Output mode (determines loss function) */
760
- output: VariantType<{
761
- /** Regression: MSE loss, no activation */
762
- regression: NullType;
763
- /** Binary: BCE loss, sigmoid activation */
764
- binary: StructType<{
765
- /** Optional per-position pos_weights for class imbalance [output_dim] */
766
- pos_weight: OptionType<ArrayType<FloatType>>;
767
- }>;
768
- /** Multiclass: CrossEntropy loss, softmax activation */
769
- multiclass: StructType<{
770
- /** Number of classes */
771
- n_classes: IntegerType;
772
- /** Optional per-class weights */
963
+ }>;
964
+ /** Output mode (determines loss function) */
965
+ output: VariantType<{
966
+ /** Regression: MSE loss, no activation */
967
+ regression: NullType;
968
+ /** Binary: BCE loss, sigmoid activation */
969
+ binary: StructType<{
970
+ /** Optional per-position pos_weights for class imbalance [output_dim] */
971
+ pos_weight: OptionType<ArrayType<FloatType>>;
972
+ }>;
973
+ /** Multiclass: CrossEntropy loss, softmax activation */
974
+ multiclass: StructType<{
975
+ /** Number of classes */
976
+ n_classes: IntegerType;
977
+ /** Optional per-class weights */
978
+ class_weights: OptionType<ArrayType<FloatType>>;
979
+ }>;
980
+ /** Multi-head categorical: N independent CrossEntropy heads */
981
+ multi_head: StructType<{
982
+ /** Number of heads (e.g., 84 time slots) */
983
+ n_heads: IntegerType;
984
+ /** Classes per head (e.g., 4 bins) */
985
+ n_classes_per_head: IntegerType;
986
+ /** Optional class weights matrix (n_heads, n_classes) */
987
+ class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
988
+ }>;
989
+ /**
990
+ * Mixed output types per head.
991
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
992
+ * Binary heads: 1 logit → sigmoid → BCE loss
993
+ * Multiclass heads: n_classes logits → softmax → CE loss
994
+ * Action vectors use one-hot encoding for multiclass heads.
995
+ */
996
+ multi_head_mixed: StructType<{
997
+ /** Array of head configurations */
998
+ heads: ArrayType<StructType<{
999
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
1000
+ head_type: VariantType<{
1001
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
1002
+ binary: NullType;
1003
+ /** Multi-class output: n_classes logits, softmax, CE loss */
1004
+ multiclass: StructType<{
1005
+ n_classes: IntegerType;
1006
+ }>;
1007
+ }>;
1008
+ /** Optional class weights for this head */
773
1009
  class_weights: OptionType<ArrayType<FloatType>>;
774
- }>;
775
- /** Multi-head categorical: N independent CrossEntropy heads */
776
- multi_head: StructType<{
777
- /** Number of heads (e.g., 84 time slots) */
778
- n_heads: IntegerType;
779
- /** Classes per head (e.g., 4 bins) */
780
- n_classes_per_head: IntegerType;
781
- /** Optional class weights matrix (n_heads, n_classes) */
782
- class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
783
- }>;
1010
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
1011
+ conditional_on: OptionType<IntegerType>;
1012
+ }>>;
1013
+ }>;
1014
+ }>;
1015
+ /** Learning rate (default: 1e-3) */
1016
+ learning_rate: OptionType<FloatType>;
1017
+ /** Maximum epochs (default: 100) */
1018
+ max_epochs: OptionType<IntegerType>;
1019
+ /** Early stopping patience (default: 10) */
1020
+ patience: OptionType<IntegerType>;
1021
+ /** Batch size (default: 32) */
1022
+ batch_size: OptionType<IntegerType>;
1023
+ /** Dropout rate (default: 0.1) */
1024
+ dropout: OptionType<FloatType>;
1025
+ /** Gradient clipping value (default: 1.0) */
1026
+ gradient_clip: OptionType<FloatType>;
1027
+ /** L2 regularization weight decay (default: 0) */
1028
+ weight_decay: OptionType<FloatType>;
1029
+ /** Random seed for reproducibility */
1030
+ random_state: OptionType<IntegerType>;
1031
+ /** Optional callback called each epoch */
1032
+ epoch_callback: OptionType<FunctionType<[IntegerType, FloatType, FloatType], NullType>>;
1033
+ }>], StructType<{
1034
+ /** Trained model blob */
1035
+ model: VariantType<{
1036
+ lightning: StructType<{
1037
+ /** Serialized model data (state_dict + hparams) */
1038
+ data: BlobType;
1039
+ /** Input dimension */
1040
+ n_features: IntegerType;
1041
+ /** Output dimension */
1042
+ output_dim: IntegerType;
1043
+ /** Architecture type */
1044
+ architecture_type: StringType;
1045
+ /** Output type */
1046
+ output_type: StringType;
1047
+ /** Latent dimension (autoencoder only) */
1048
+ latent_dim: OptionType<IntegerType>;
1049
+ }>;
1050
+ }>;
1051
+ /** Final training loss */
1052
+ train_loss: FloatType;
1053
+ /** Final validation loss */
1054
+ val_loss: FloatType;
1055
+ /** Best epoch (for early stopping) */
1056
+ best_epoch: IntegerType;
1057
+ }>>;
1058
+ /**
1059
+ * Generate action sequences autoregressively from trajectory model.
1060
+ *
1061
+ * Use with models trained via trainTrajectory.
1062
+ *
1063
+ * @param model - Trained model from trainTrajectory
1064
+ * @param states - State matrices: n_samples × (seq_len, state_dim)
1065
+ * @param target_returns - Target returns: (n_samples,)
1066
+ * @param config - Generation configuration
1067
+ * @returns Generated actions: n_samples × (seq_len, action_dim)
1068
+ */
1069
+ export declare const lightning_generate_trajectory: import("@elaraai/east").PlatformDefinition<[VariantType<{
1070
+ lightning: StructType<{
1071
+ /** Serialized model data (state_dict + hparams) */
1072
+ data: BlobType;
1073
+ /** Input dimension */
1074
+ n_features: IntegerType;
1075
+ /** Output dimension */
1076
+ output_dim: IntegerType;
1077
+ /** Architecture type */
1078
+ architecture_type: StringType;
1079
+ /** Output type */
1080
+ output_type: StringType;
1081
+ /** Latent dimension (autoencoder only) */
1082
+ latent_dim: OptionType<IntegerType>;
1083
+ }>;
1084
+ }>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<FloatType>, StructType<{
1085
+ /** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
1086
+ temperature: FloatType;
1087
+ /** Whether to return probabilities or samples */
1088
+ return_probs: BooleanType;
1089
+ /** Optional constraint mask: (seq_len, action_dim) - FALSE disables action (typed as a float matrix; presumably 0.0 = FALSE - confirm) */
1090
+ action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
1091
+ /** Optional temporal mask: (seq_len,) - FALSE marks invalid timesteps (typed as a float vector; presumably 0.0 = FALSE - confirm) */
1092
+ temporal_mask: OptionType<ArrayType<FloatType>>;
1093
+ /** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
1094
+ head_configs: OptionType<ArrayType<StructType<{
1095
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
1096
+ head_type: VariantType<{
1097
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
1098
+ binary: NullType;
1099
+ /** Multi-class output: n_classes logits, softmax, CE loss */
1100
+ multiclass: StructType<{
1101
+ n_classes: IntegerType;
1102
+ }>;
1103
+ }>;
1104
+ /** Optional class weights for this head */
1105
+ class_weights: OptionType<ArrayType<FloatType>>;
1106
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
1107
+ conditional_on: OptionType<IntegerType>;
1108
+ }>>>;
1109
+ }>], ArrayType<ArrayType<ArrayType<FloatType>>>>;
1110
+ /**
1111
+ * Lightning types namespace.
1112
+ */
1113
+ export declare const LightningTypes: {
1114
+ readonly OutputType: VariantType<{
1115
+ /** Regression: MSE loss, no activation */
1116
+ regression: NullType;
1117
+ /** Binary: BCE loss, sigmoid activation */
1118
+ binary: StructType<{
1119
+ /** Optional per-position pos_weights for class imbalance [output_dim] */
1120
+ pos_weight: OptionType<ArrayType<FloatType>>;
1121
+ }>;
1122
+ /** Multiclass: CrossEntropy loss, softmax activation */
1123
+ multiclass: StructType<{
1124
+ /** Number of classes */
1125
+ n_classes: IntegerType;
1126
+ /** Optional per-class weights */
1127
+ class_weights: OptionType<ArrayType<FloatType>>;
1128
+ }>;
1129
+ /** Multi-head categorical: N independent CrossEntropy heads */
1130
+ multi_head: StructType<{
1131
+ /** Number of heads (e.g., 84 time slots) */
1132
+ n_heads: IntegerType;
1133
+ /** Classes per head (e.g., 4 bins) */
1134
+ n_classes_per_head: IntegerType;
1135
+ /** Optional class weights matrix (n_heads, n_classes) */
1136
+ class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
1137
+ }>;
1138
+ /**
1139
+ * Mixed output types per head.
1140
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
1141
+ * Binary heads: 1 logit → sigmoid → BCE loss
1142
+ * Multiclass heads: n_classes logits → softmax → CE loss
1143
+ * Action vectors use one-hot encoding for multiclass heads.
1144
+ */
1145
+ multi_head_mixed: StructType<{
1146
+ /** Array of head configurations */
1147
+ heads: ArrayType<StructType<{
1148
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
1149
+ head_type: VariantType<{
1150
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
1151
+ binary: NullType;
1152
+ /** Multi-class output: n_classes logits, softmax, CE loss */
1153
+ multiclass: StructType<{
1154
+ n_classes: IntegerType;
1155
+ }>;
1156
+ }>;
1157
+ /** Optional class weights for this head */
1158
+ class_weights: OptionType<ArrayType<FloatType>>;
1159
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
1160
+ conditional_on: OptionType<IntegerType>;
1161
+ }>>;
1162
+ }>;
1163
+ }>;
1164
+ readonly ArchitectureType: VariantType<{
1165
+ /** Simple MLP: input → hidden → output */
1166
+ mlp: StructType<{
1167
+ /** Hidden layer sizes */
1168
+ hidden_layers: ArrayType<IntegerType>;
1169
+ }>;
1170
+ /** Autoencoder: input → encoder → latent → decoder → output */
1171
+ autoencoder: StructType<{
1172
+ /** Encoder hidden layer sizes */
1173
+ encoder_layers: ArrayType<IntegerType>;
1174
+ /** Latent dimension (bottleneck) */
1175
+ latent_dim: IntegerType;
1176
+ /** Decoder hidden layer sizes */
1177
+ decoder_layers: ArrayType<IntegerType>;
1178
+ }>;
1179
+ /** Conv1D: 1D convolutional autoencoder for temporal patterns */
1180
+ conv1d: StructType<{
1181
+ /** Number of channels (e.g., additive types) */
1182
+ n_channels: IntegerType;
1183
+ /** Sequence length (e.g., days) */
1184
+ sequence_length: IntegerType;
1185
+ /** Conv layer channel sizes */
1186
+ conv_channels: ArrayType<IntegerType>;
1187
+ /** Kernel size for convolutions (must be odd) */
1188
+ kernel_size: IntegerType;
1189
+ /** Latent dimension after flattening */
1190
+ latent_dim: IntegerType;
1191
+ /** Optional condition dimension for conditional generation */
1192
+ condition_dim: OptionType<IntegerType>;
1193
+ }>;
1194
+ /** Sequential: LSTM/GRU autoencoder for long-range dependencies */
1195
+ sequential: StructType<{
1196
+ /** Number of channels (e.g., additive types) */
1197
+ n_channels: IntegerType;
1198
+ /** Sequence length (e.g., days) */
1199
+ sequence_length: IntegerType;
1200
+ /** RNN hidden size */
1201
+ hidden_size: IntegerType;
1202
+ /** Number of RNN layers */
1203
+ n_layers: IntegerType;
1204
+ /** Cell type: lstm or gru */
1205
+ cell_type: VariantType<{
1206
+ lstm: NullType;
1207
+ gru: NullType;
1208
+ }>;
1209
+ /** Latent dimension (from final hidden state) */
1210
+ latent_dim: IntegerType;
1211
+ /** Bidirectional encoder (decoder is always unidirectional) */
1212
+ bidirectional: BooleanType;
1213
+ /** Optional condition dimension for conditional generation */
1214
+ condition_dim: OptionType<IntegerType>;
1215
+ }>;
1216
+ /** Transformer: attention-based autoencoder for complex patterns */
1217
+ transformer: StructType<{
1218
+ /** Number of channels (e.g., additive types) */
1219
+ n_channels: IntegerType;
1220
+ /** Sequence length (e.g., days) */
1221
+ sequence_length: IntegerType;
1222
+ /** Model dimension */
1223
+ d_model: IntegerType;
1224
+ /** Number of attention heads (must divide d_model evenly) */
1225
+ n_attention_heads: IntegerType;
1226
+ /** Number of transformer layers */
1227
+ n_layers: IntegerType;
1228
+ /** Feedforward dimension (default: 4 * d_model) */
1229
+ d_ff: OptionType<IntegerType>;
1230
+ /** Latent dimension (mean pooled output) */
1231
+ latent_dim: IntegerType;
1232
+ /** Optional condition dimension for conditional generation */
1233
+ condition_dim: OptionType<IntegerType>;
1234
+ }>;
1235
+ /**
1236
+ * Decision Transformer: return-conditioned sequence generation.
1237
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
1238
+ * Predicts actions conditioned on desired return and state history.
1239
+ */
1240
+ decision_transformer: StructType<{
1241
+ /** Sequence length (timesteps) */
1242
+ sequence_length: IntegerType;
1243
+ /** State dimension per timestep */
1244
+ state_dim: IntegerType;
1245
+ /** Action dimension per timestep */
1246
+ action_dim: IntegerType;
1247
+ /** Model dimension (transformer hidden size) */
1248
+ d_model: IntegerType;
1249
+ /** Number of attention heads */
1250
+ n_attention_heads: IntegerType;
1251
+ /** Number of transformer layers */
1252
+ n_layers: IntegerType;
1253
+ /** Feedforward dimension (default: 4 * d_model) */
1254
+ d_ff: OptionType<IntegerType>;
1255
+ /** Dropout rate */
1256
+ dropout: OptionType<FloatType>;
1257
+ /** Whether return is per-timestep or global */
1258
+ return_embedding: VariantType<{
1259
+ /** Single return value for entire sequence */
1260
+ global: NullType;
1261
+ /** Return-to-go at each timestep */
1262
+ per_timestep: NullType;
1263
+ }>;
1264
+ }>;
1265
+ }>;
1266
+ readonly CellType: VariantType<{
1267
+ lstm: NullType;
1268
+ gru: NullType;
1269
+ }>;
1270
+ readonly EpochCallbackType: FunctionType<[IntegerType, FloatType, FloatType], NullType>;
1271
+ readonly ConfigType: StructType<{
1272
+ /** Model architecture */
1273
+ architecture: VariantType<{
1274
+ /** Simple MLP: input → hidden → output */
1275
+ mlp: StructType<{
1276
+ /** Hidden layer sizes */
1277
+ hidden_layers: ArrayType<IntegerType>;
1278
+ }>;
1279
+ /** Autoencoder: input → encoder → latent → decoder → output */
1280
+ autoencoder: StructType<{
1281
+ /** Encoder hidden layer sizes */
1282
+ encoder_layers: ArrayType<IntegerType>;
1283
+ /** Latent dimension (bottleneck) */
1284
+ latent_dim: IntegerType;
1285
+ /** Decoder hidden layer sizes */
1286
+ decoder_layers: ArrayType<IntegerType>;
1287
+ }>;
1288
+ /** Conv1D: 1D convolutional autoencoder for temporal patterns */
1289
+ conv1d: StructType<{
1290
+ /** Number of channels (e.g., additive types) */
1291
+ n_channels: IntegerType;
1292
+ /** Sequence length (e.g., days) */
1293
+ sequence_length: IntegerType;
1294
+ /** Conv layer channel sizes */
1295
+ conv_channels: ArrayType<IntegerType>;
1296
+ /** Kernel size for convolutions (must be odd) */
1297
+ kernel_size: IntegerType;
1298
+ /** Latent dimension after flattening */
1299
+ latent_dim: IntegerType;
1300
+ /** Optional condition dimension for conditional generation */
1301
+ condition_dim: OptionType<IntegerType>;
1302
+ }>;
1303
+ /** Sequential: LSTM/GRU autoencoder for long-range dependencies */
1304
+ sequential: StructType<{
1305
+ /** Number of channels (e.g., additive types) */
1306
+ n_channels: IntegerType;
1307
+ /** Sequence length (e.g., days) */
1308
+ sequence_length: IntegerType;
1309
+ /** RNN hidden size */
1310
+ hidden_size: IntegerType;
1311
+ /** Number of RNN layers */
1312
+ n_layers: IntegerType;
1313
+ /** Cell type: lstm or gru */
1314
+ cell_type: VariantType<{
1315
+ lstm: NullType;
1316
+ gru: NullType;
1317
+ }>;
1318
+ /** Latent dimension (from final hidden state) */
1319
+ latent_dim: IntegerType;
1320
+ /** Bidirectional encoder (decoder is always unidirectional) */
1321
+ bidirectional: BooleanType;
1322
+ /** Optional condition dimension for conditional generation */
1323
+ condition_dim: OptionType<IntegerType>;
1324
+ }>;
1325
+ /** Transformer: attention-based autoencoder for complex patterns */
1326
+ transformer: StructType<{
1327
+ /** Number of channels (e.g., additive types) */
1328
+ n_channels: IntegerType;
1329
+ /** Sequence length (e.g., days) */
1330
+ sequence_length: IntegerType;
1331
+ /** Model dimension */
1332
+ d_model: IntegerType;
1333
+ /** Number of attention heads (must divide d_model evenly) */
1334
+ n_attention_heads: IntegerType;
1335
+ /** Number of transformer layers */
1336
+ n_layers: IntegerType;
1337
+ /** Feedforward dimension (default: 4 * d_model) */
1338
+ d_ff: OptionType<IntegerType>;
1339
+ /** Latent dimension (mean pooled output) */
1340
+ latent_dim: IntegerType;
1341
+ /** Optional condition dimension for conditional generation */
1342
+ condition_dim: OptionType<IntegerType>;
1343
+ }>;
1344
+ /**
1345
+ * Decision Transformer: return-conditioned sequence generation.
1346
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
1347
+ * Predicts actions conditioned on desired return and state history.
1348
+ */
1349
+ decision_transformer: StructType<{
1350
+ /** Sequence length (timesteps) */
1351
+ sequence_length: IntegerType;
1352
+ /** State dimension per timestep */
1353
+ state_dim: IntegerType;
1354
+ /** Action dimension per timestep */
1355
+ action_dim: IntegerType;
1356
+ /** Model dimension (transformer hidden size) */
1357
+ d_model: IntegerType;
1358
+ /** Number of attention heads */
1359
+ n_attention_heads: IntegerType;
1360
+ /** Number of transformer layers */
1361
+ n_layers: IntegerType;
1362
+ /** Feedforward dimension (default: 4 * d_model) */
1363
+ d_ff: OptionType<IntegerType>;
1364
+ /** Dropout rate */
1365
+ dropout: OptionType<FloatType>;
1366
+ /** Whether return is per-timestep or global */
1367
+ return_embedding: VariantType<{
1368
+ /** Single return value for entire sequence */
1369
+ global: NullType;
1370
+ /** Return-to-go at each timestep */
1371
+ per_timestep: NullType;
1372
+ }>;
1373
+ }>;
1374
+ }>;
1375
+ /** Output mode (determines loss function) */
1376
+ output: VariantType<{
1377
+ /** Regression: MSE loss, no activation */
1378
+ regression: NullType;
1379
+ /** Binary: BCE loss, sigmoid activation */
1380
+ binary: StructType<{
1381
+ /** Optional per-position pos_weights for class imbalance [output_dim] */
1382
+ pos_weight: OptionType<ArrayType<FloatType>>;
1383
+ }>;
1384
+ /** Multiclass: CrossEntropy loss, softmax activation */
1385
+ multiclass: StructType<{
1386
+ /** Number of classes */
1387
+ n_classes: IntegerType;
1388
+ /** Optional per-class weights */
1389
+ class_weights: OptionType<ArrayType<FloatType>>;
1390
+ }>;
1391
+ /** Multi-head categorical: N independent CrossEntropy heads */
1392
+ multi_head: StructType<{
1393
+ /** Number of heads (e.g., 84 time slots) */
1394
+ n_heads: IntegerType;
1395
+ /** Classes per head (e.g., 4 bins) */
1396
+ n_classes_per_head: IntegerType;
1397
+ /** Optional class weights matrix (n_heads, n_classes) */
1398
+ class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
1399
+ }>;
1400
+ /**
1401
+ * Mixed output types per head.
1402
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
1403
+ * Binary heads: 1 logit → sigmoid → BCE loss
1404
+ * Multiclass heads: n_classes logits → softmax → CE loss
1405
+ * Action vectors use one-hot encoding for multiclass heads.
1406
+ */
1407
+ multi_head_mixed: StructType<{
1408
+ /** Array of head configurations */
1409
+ heads: ArrayType<StructType<{
1410
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
1411
+ head_type: VariantType<{
1412
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
1413
+ binary: NullType;
1414
+ /** Multi-class output: n_classes logits, softmax, CE loss */
1415
+ multiclass: StructType<{
1416
+ n_classes: IntegerType;
1417
+ }>;
1418
+ }>;
1419
+ /** Optional class weights for this head */
1420
+ class_weights: OptionType<ArrayType<FloatType>>;
1421
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
1422
+ conditional_on: OptionType<IntegerType>;
1423
+ }>>;
1424
+ }>;
784
1425
  }>;
785
1426
  /** Learning rate (default: 1e-3) */
786
1427
  learning_rate: OptionType<FloatType>;
@@ -854,6 +1495,61 @@ export declare const LightningTypes: {
854
1495
  /** Group index per sample: [n_samples] */
855
1496
  sample_groups: ArrayType<IntegerType>;
856
1497
  }>;
1498
+ readonly GenerateConfigType: StructType<{
1499
+ /** Number of steps to generate */
1500
+ n_steps: IntegerType;
1501
+ /** Sampling temperature: 0.0 = argmax, > 0 = scaled sampling */
1502
+ temperature: FloatType;
1503
+ /** If true, return probabilities. If false, return samples. */
1504
+ return_probs: BooleanType;
1505
+ }>;
1506
+ readonly ReturnEmbeddingType: VariantType<{
1507
+ /** Single return value for entire sequence */
1508
+ global: NullType;
1509
+ /** Return-to-go at each timestep */
1510
+ per_timestep: NullType;
1511
+ }>;
1512
+ readonly HeadConfigType: StructType<{
1513
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
1514
+ head_type: VariantType<{
1515
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
1516
+ binary: NullType;
1517
+ /** Multi-class output: n_classes logits, softmax, CE loss */
1518
+ multiclass: StructType<{
1519
+ n_classes: IntegerType;
1520
+ }>;
1521
+ }>;
1522
+ /** Optional class weights for this head */
1523
+ class_weights: OptionType<ArrayType<FloatType>>;
1524
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
1525
+ conditional_on: OptionType<IntegerType>;
1526
+ }>;
1527
+ readonly TrajectoryGenerateConfigType: StructType<{
1528
+ /** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
1529
+ temperature: FloatType;
1530
+ /** Whether to return probabilities or samples */
1531
+ return_probs: BooleanType;
1532
+ /** Optional constraint mask: (seq_len, action_dim) - FALSE disables action */
1533
+ action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
1534
+ /** Optional temporal mask: (seq_len,) - FALSE marks invalid timesteps */
1535
+ temporal_mask: OptionType<ArrayType<FloatType>>;
1536
+ /** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
1537
+ head_configs: OptionType<ArrayType<StructType<{
1538
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
1539
+ head_type: VariantType<{
1540
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
1541
+ binary: NullType;
1542
+ /** Multi-class output: n_classes logits, softmax, CE loss */
1543
+ multiclass: StructType<{
1544
+ n_classes: IntegerType;
1545
+ }>;
1546
+ }>;
1547
+ /** Optional class weights for this head */
1548
+ class_weights: OptionType<ArrayType<FloatType>>;
1549
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
1550
+ conditional_on: OptionType<IntegerType>;
1551
+ }>>>;
1552
+ }>;
857
1553
  };
858
1554
  /**
859
1555
  * Lightning platform functions namespace.
@@ -959,6 +1655,36 @@ export declare const Lightning: {
959
1655
  /** Optional condition dimension for conditional generation */
960
1656
  condition_dim: OptionType<IntegerType>;
961
1657
  }>;
1658
+ /**
1659
+ * Decision Transformer: return-conditioned sequence generation.
1660
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
1661
+ * Predicts actions conditioned on desired return and state history.
1662
+ */
1663
+ decision_transformer: StructType<{
1664
+ /** Sequence length (timesteps) */
1665
+ sequence_length: IntegerType;
1666
+ /** State dimension per timestep */
1667
+ state_dim: IntegerType;
1668
+ /** Action dimension per timestep */
1669
+ action_dim: IntegerType;
1670
+ /** Model dimension (transformer hidden size) */
1671
+ d_model: IntegerType;
1672
+ /** Number of attention heads */
1673
+ n_attention_heads: IntegerType;
1674
+ /** Number of transformer layers */
1675
+ n_layers: IntegerType;
1676
+ /** Feedforward dimension (default: 4 * d_model) */
1677
+ d_ff: OptionType<IntegerType>;
1678
+ /** Dropout rate */
1679
+ dropout: OptionType<FloatType>;
1680
+ /** Whether return is per-timestep or global */
1681
+ return_embedding: VariantType<{
1682
+ /** Single return value for entire sequence */
1683
+ global: NullType;
1684
+ /** Return-to-go at each timestep */
1685
+ per_timestep: NullType;
1686
+ }>;
1687
+ }>;
962
1688
  }>;
963
1689
  /** Output mode (determines loss function) */
964
1690
  output: VariantType<{
@@ -985,6 +1711,31 @@ export declare const Lightning: {
985
1711
  /** Optional class weights matrix (n_heads, n_classes) */
986
1712
  class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
987
1713
  }>;
1714
+ /**
1715
+ * Mixed output types per head.
1716
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
1717
+ * Binary heads: 1 logit → sigmoid → BCE loss
1718
+ * Multiclass heads: n_classes logits → softmax → CE loss
1719
+ * Action vectors use one-hot encoding for multiclass heads.
1720
+ */
1721
+ multi_head_mixed: StructType<{
1722
+ /** Array of head configurations */
1723
+ heads: ArrayType<StructType<{
1724
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
1725
+ head_type: VariantType<{
1726
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
1727
+ binary: NullType;
1728
+ /** Multi-class output: n_classes logits, softmax, CE loss */
1729
+ multiclass: StructType<{
1730
+ n_classes: IntegerType;
1731
+ }>;
1732
+ }>;
1733
+ /** Optional class weights for this head */
1734
+ class_weights: OptionType<ArrayType<FloatType>>;
1735
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
1736
+ conditional_on: OptionType<IntegerType>;
1737
+ }>>;
1738
+ }>;
988
1739
  }>;
989
1740
  /** Learning rate (default: 1e-3) */
990
1741
  learning_rate: OptionType<FloatType>;
@@ -1125,6 +1876,308 @@ export declare const Lightning: {
1125
1876
  latent_dim: OptionType<IntegerType>;
1126
1877
  }>;
1127
1878
  }>, ArrayType<ArrayType<FloatType>>, ArrayType<ArrayType<FloatType>>], ArrayType<ArrayType<FloatType>>>;
1879
+ /**
1880
+ * Generate sequence autoregressively.
1881
+ *
1882
+ * Generates a sequence from a trained sequential model, optionally
1883
+ * continuing from a prefix and conditioned on input features.
1884
+ */
1885
+ readonly generateSequence: import("@elaraai/east").PlatformDefinition<[VariantType<{
1886
+ lightning: StructType<{
1887
+ /** Serialized model data (state_dict + hparams) */
1888
+ data: BlobType;
1889
+ /** Input dimension */
1890
+ n_features: IntegerType;
1891
+ /** Output dimension */
1892
+ output_dim: IntegerType;
1893
+ /** Architecture type */
1894
+ architecture_type: StringType;
1895
+ /** Output type */
1896
+ output_type: StringType;
1897
+ /** Latent dimension (autoencoder only) */
1898
+ latent_dim: OptionType<IntegerType>;
1899
+ }>;
1900
+ }>, ArrayType<ArrayType<FloatType>>, OptionType<ArrayType<ArrayType<FloatType>>>, StructType<{
1901
+ /** Number of steps to generate */
1902
+ n_steps: IntegerType;
1903
+ /** Sampling temperature: 0.0 = argmax, > 0 = scaled sampling */
1904
+ temperature: FloatType;
1905
+ /** If true, return probabilities. If false, return samples. */
1906
+ return_probs: BooleanType;
1907
+ }>], ArrayType<ArrayType<FloatType>>>;
1908
+ /**
1909
+ * Train a Decision Transformer with trajectory data.
1910
+ *
1911
+ * Trains a return-conditioned sequence generation model that learns
1912
+ * to predict actions given states and desired returns.
1913
+ *
1914
+ * @example
1915
+ * ```typescript
1916
+ * const result = Lightning.trainTrajectory(
1917
+ * returns, states, actions, masks,
1918
+ * {
1919
+ * architecture: variant("decision_transformer", {
1920
+ * sequence_length: 14n,
1921
+ * state_dim: 8n,
1922
+ * action_dim: 11n,
1923
+ * d_model: 64n,
1924
+ * n_attention_heads: 4n,
1925
+ * n_layers: 3n,
1926
+ * d_ff: variant("none", null),
1927
+ * dropout: variant("some", 0.1),
1928
+ * return_embedding: variant("global", null),
1929
+ * }),
1930
+ * output: variant("multi_head_mixed", { heads: [...] }),
1931
+ * ...
1932
+ * }
1933
+ * );
1934
+ * ```
1935
+ */
1936
+ readonly trainTrajectory: import("@elaraai/east").PlatformDefinition<[ArrayType<FloatType>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<ArrayType<FloatType>>, StructType<{
1937
+ /** Model architecture */
1938
+ architecture: VariantType<{
1939
+ /** Simple MLP: input → hidden → output */
1940
+ mlp: StructType<{
1941
+ /** Hidden layer sizes */
1942
+ hidden_layers: ArrayType<IntegerType>;
1943
+ }>;
1944
+ /** Autoencoder: input → encoder → latent → decoder → output */
1945
+ autoencoder: StructType<{
1946
+ /** Encoder hidden layer sizes */
1947
+ encoder_layers: ArrayType<IntegerType>;
1948
+ /** Latent dimension (bottleneck) */
1949
+ latent_dim: IntegerType;
1950
+ /** Decoder hidden layer sizes */
1951
+ decoder_layers: ArrayType<IntegerType>;
1952
+ }>;
1953
+ /** Conv1D: 1D convolutional autoencoder for temporal patterns */
1954
+ conv1d: StructType<{
1955
+ /** Number of channels (e.g., additive types) */
1956
+ n_channels: IntegerType;
1957
+ /** Sequence length (e.g., days) */
1958
+ sequence_length: IntegerType;
1959
+ /** Conv layer channel sizes */
1960
+ conv_channels: ArrayType<IntegerType>;
1961
+ /** Kernel size for convolutions (must be odd) */
1962
+ kernel_size: IntegerType;
1963
+ /** Latent dimension after flattening */
1964
+ latent_dim: IntegerType;
1965
+ /** Optional condition dimension for conditional generation */
1966
+ condition_dim: OptionType<IntegerType>;
1967
+ }>;
1968
+ /** Sequential: LSTM/GRU autoencoder for long-range dependencies */
1969
+ sequential: StructType<{
1970
+ /** Number of channels (e.g., additive types) */
1971
+ n_channels: IntegerType;
1972
+ /** Sequence length (e.g., days) */
1973
+ sequence_length: IntegerType;
1974
+ /** RNN hidden size */
1975
+ hidden_size: IntegerType;
1976
+ /** Number of RNN layers */
1977
+ n_layers: IntegerType;
1978
+ /** Cell type: lstm or gru */
1979
+ cell_type: VariantType<{
1980
+ lstm: NullType;
1981
+ gru: NullType;
1982
+ }>;
1983
+ /** Latent dimension (from final hidden state) */
1984
+ latent_dim: IntegerType;
1985
+ /** Bidirectional encoder (decoder is always unidirectional) */
1986
+ bidirectional: BooleanType;
1987
+ /** Optional condition dimension for conditional generation */
1988
+ condition_dim: OptionType<IntegerType>;
1989
+ }>;
1990
+ /** Transformer: attention-based autoencoder for complex patterns */
1991
+ transformer: StructType<{
1992
+ /** Number of channels (e.g., additive types) */
1993
+ n_channels: IntegerType;
1994
+ /** Sequence length (e.g., days) */
1995
+ sequence_length: IntegerType;
1996
+ /** Model dimension */
1997
+ d_model: IntegerType;
1998
+ /** Number of attention heads (must divide d_model evenly) */
1999
+ n_attention_heads: IntegerType;
2000
+ /** Number of transformer layers */
2001
+ n_layers: IntegerType;
2002
+ /** Feedforward dimension (default: 4 * d_model) */
2003
+ d_ff: OptionType<IntegerType>;
2004
+ /** Latent dimension (mean pooled output) */
2005
+ latent_dim: IntegerType;
2006
+ /** Optional condition dimension for conditional generation */
2007
+ condition_dim: OptionType<IntegerType>;
2008
+ }>;
2009
+ /**
2010
+ * Decision Transformer: return-conditioned sequence generation.
2011
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
2012
+ * Predicts actions conditioned on desired return and state history.
2013
+ */
2014
+ decision_transformer: StructType<{
2015
+ /** Sequence length (timesteps) */
2016
+ sequence_length: IntegerType;
2017
+ /** State dimension per timestep */
2018
+ state_dim: IntegerType;
2019
+ /** Action dimension per timestep */
2020
+ action_dim: IntegerType;
2021
+ /** Model dimension (transformer hidden size) */
2022
+ d_model: IntegerType;
2023
+ /** Number of attention heads */
2024
+ n_attention_heads: IntegerType;
2025
+ /** Number of transformer layers */
2026
+ n_layers: IntegerType;
2027
+ /** Feedforward dimension (default: 4 * d_model) */
2028
+ d_ff: OptionType<IntegerType>;
2029
+ /** Dropout rate */
2030
+ dropout: OptionType<FloatType>;
2031
+ /** Whether return is per-timestep or global */
2032
+ return_embedding: VariantType<{
2033
+ /** Single return value for entire sequence */
2034
+ global: NullType;
2035
+ /** Return-to-go at each timestep */
2036
+ per_timestep: NullType;
2037
+ }>;
2038
+ }>;
2039
+ }>;
2040
+ /** Output mode (determines loss function) */
2041
+ output: VariantType<{
2042
+ /** Regression: MSE loss, no activation */
2043
+ regression: NullType;
2044
+ /** Binary: BCE loss, sigmoid activation */
2045
+ binary: StructType<{
2046
+ /** Optional per-position pos_weights for class imbalance [output_dim] */
2047
+ pos_weight: OptionType<ArrayType<FloatType>>;
2048
+ }>;
2049
+ /** Multiclass: CrossEntropy loss, softmax activation */
2050
+ multiclass: StructType<{
2051
+ /** Number of classes */
2052
+ n_classes: IntegerType;
2053
+ /** Optional per-class weights */
2054
+ class_weights: OptionType<ArrayType<FloatType>>;
2055
+ }>;
2056
+ /** Multi-head categorical: N independent CrossEntropy heads */
2057
+ multi_head: StructType<{
2058
+ /** Number of heads (e.g., 84 time slots) */
2059
+ n_heads: IntegerType;
2060
+ /** Classes per head (e.g., 4 bins) */
2061
+ n_classes_per_head: IntegerType;
2062
+ /** Optional class weights matrix (n_heads, n_classes) */
2063
+ class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
2064
+ }>;
2065
+ /**
2066
+ * Mixed output types per head.
2067
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
2068
+ * Binary heads: 1 logit → sigmoid → BCE loss
2069
+ * Multiclass heads: n_classes logits → softmax → CE loss
2070
+ * Action vectors use one-hot encoding for multiclass heads.
2071
+ */
2072
+ multi_head_mixed: StructType<{
2073
+ /** Array of head configurations */
2074
+ heads: ArrayType<StructType<{
2075
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
2076
+ head_type: VariantType<{
2077
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
2078
+ binary: NullType;
2079
+ /** Multi-class output: n_classes logits, softmax, CE loss */
2080
+ multiclass: StructType<{
2081
+ n_classes: IntegerType;
2082
+ }>;
2083
+ }>;
2084
+ /** Optional class weights for this head */
2085
+ class_weights: OptionType<ArrayType<FloatType>>;
2086
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
2087
+ conditional_on: OptionType<IntegerType>;
2088
+ }>>;
2089
+ }>;
2090
+ }>;
2091
+ /** Learning rate (default: 1e-3) */
2092
+ learning_rate: OptionType<FloatType>;
2093
+ /** Maximum epochs (default: 100) */
2094
+ max_epochs: OptionType<IntegerType>;
2095
+ /** Early stopping patience (default: 10) */
2096
+ patience: OptionType<IntegerType>;
2097
+ /** Batch size (default: 32) */
2098
+ batch_size: OptionType<IntegerType>;
2099
+ /** Dropout rate (default: 0.1) */
2100
+ dropout: OptionType<FloatType>;
2101
+ /** Gradient clipping value (default: 1.0) */
2102
+ gradient_clip: OptionType<FloatType>;
2103
+ /** L2 regularization weight decay (default: 0) */
2104
+ weight_decay: OptionType<FloatType>;
2105
+ /** Random seed for reproducibility */
2106
+ random_state: OptionType<IntegerType>;
2107
+ /** Optional callback called each epoch */
2108
+ epoch_callback: OptionType<FunctionType<[IntegerType, FloatType, FloatType], NullType>>;
2109
+ }>], StructType<{
2110
+ /** Trained model blob */
2111
+ model: VariantType<{
2112
+ lightning: StructType<{
2113
+ /** Serialized model data (state_dict + hparams) */
2114
+ data: BlobType;
2115
+ /** Input dimension */
2116
+ n_features: IntegerType;
2117
+ /** Output dimension */
2118
+ output_dim: IntegerType;
2119
+ /** Architecture type */
2120
+ architecture_type: StringType;
2121
+ /** Output type */
2122
+ output_type: StringType;
2123
+ /** Latent dimension (autoencoder only) */
2124
+ latent_dim: OptionType<IntegerType>;
2125
+ }>;
2126
+ }>;
2127
+ /** Final training loss */
2128
+ train_loss: FloatType;
2129
+ /** Final validation loss */
2130
+ val_loss: FloatType;
2131
+ /** Best epoch (for early stopping) */
2132
+ best_epoch: IntegerType;
2133
+ }>>;
2134
+ /**
2135
+ * Generate action sequences from a Decision Transformer.
2136
+ *
2137
+ * Autoregressively generates actions conditioned on target returns
2138
+ * and state sequences.
2139
+ */
2140
+ readonly generateTrajectory: import("@elaraai/east").PlatformDefinition<[VariantType<{
2141
+ lightning: StructType<{
2142
+ /** Serialized model data (state_dict + hparams) */
2143
+ data: BlobType;
2144
+ /** Input dimension */
2145
+ n_features: IntegerType;
2146
+ /** Output dimension */
2147
+ output_dim: IntegerType;
2148
+ /** Architecture type */
2149
+ architecture_type: StringType;
2150
+ /** Output type */
2151
+ output_type: StringType;
2152
+ /** Latent dimension (autoencoder only) */
2153
+ latent_dim: OptionType<IntegerType>;
2154
+ }>;
2155
+ }>, ArrayType<ArrayType<ArrayType<FloatType>>>, ArrayType<FloatType>, StructType<{
2156
+ /** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
2157
+ temperature: FloatType;
2158
+ /** Whether to return probabilities or samples */
2159
+ return_probs: BooleanType;
2160
+ /** Optional constraint mask: (seq_len, action_dim) - FALSE disables action */
2161
+ action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
2162
+ /** Optional temporal mask: (seq_len,) - FALSE marks invalid timesteps */
2163
+ temporal_mask: OptionType<ArrayType<FloatType>>;
2164
+ /** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
2165
+ head_configs: OptionType<ArrayType<StructType<{
2166
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
2167
+ head_type: VariantType<{
2168
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
2169
+ binary: NullType;
2170
+ /** Multi-class output: n_classes logits, softmax, CE loss */
2171
+ multiclass: StructType<{
2172
+ n_classes: IntegerType;
2173
+ }>;
2174
+ }>;
2175
+ /** Optional class weights for this head */
2176
+ class_weights: OptionType<ArrayType<FloatType>>;
2177
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
2178
+ conditional_on: OptionType<IntegerType>;
2179
+ }>>>;
2180
+ }>], ArrayType<ArrayType<ArrayType<FloatType>>>>;
1128
2181
  /**
1129
2182
  * Type definitions for Lightning functions.
1130
2183
  */
@@ -1153,6 +2206,31 @@ export declare const Lightning: {
1153
2206
  /** Optional class weights matrix (n_heads, n_classes) */
1154
2207
  class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
1155
2208
  }>;
2209
+ /**
2210
+ * Mixed output types per head.
2211
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
2212
+ * Binary heads: 1 logit → sigmoid → BCE loss
2213
+ * Multiclass heads: n_classes logits → softmax → CE loss
2214
+ * Action vectors use one-hot encoding for multiclass heads.
2215
+ */
2216
+ multi_head_mixed: StructType<{
2217
+ /** Array of head configurations */
2218
+ heads: ArrayType<StructType<{
2219
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
2220
+ head_type: VariantType<{
2221
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
2222
+ binary: NullType;
2223
+ /** Multi-class output: n_classes logits, softmax, CE loss */
2224
+ multiclass: StructType<{
2225
+ n_classes: IntegerType;
2226
+ }>;
2227
+ }>;
2228
+ /** Optional class weights for this head */
2229
+ class_weights: OptionType<ArrayType<FloatType>>;
2230
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
2231
+ conditional_on: OptionType<IntegerType>;
2232
+ }>>;
2233
+ }>;
1156
2234
  }>;
1157
2235
  readonly ArchitectureType: VariantType<{
1158
2236
  /** Simple MLP: input → hidden → output */
@@ -1225,6 +2303,36 @@ export declare const Lightning: {
1225
2303
  /** Optional condition dimension for conditional generation */
1226
2304
  condition_dim: OptionType<IntegerType>;
1227
2305
  }>;
2306
+ /**
2307
+ * Decision Transformer: return-conditioned sequence generation.
2308
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
2309
+ * Predicts actions conditioned on desired return and state history.
2310
+ */
2311
+ decision_transformer: StructType<{
2312
+ /** Sequence length (timesteps) */
2313
+ sequence_length: IntegerType;
2314
+ /** State dimension per timestep */
2315
+ state_dim: IntegerType;
2316
+ /** Action dimension per timestep */
2317
+ action_dim: IntegerType;
2318
+ /** Model dimension (transformer hidden size) */
2319
+ d_model: IntegerType;
2320
+ /** Number of attention heads */
2321
+ n_attention_heads: IntegerType;
2322
+ /** Number of transformer layers */
2323
+ n_layers: IntegerType;
2324
+ /** Feedforward dimension (default: 4 * d_model) */
2325
+ d_ff: OptionType<IntegerType>;
2326
+ /** Dropout rate */
2327
+ dropout: OptionType<FloatType>;
2328
+ /** Whether return is per-timestep or global */
2329
+ return_embedding: VariantType<{
2330
+ /** Single return value for entire sequence */
2331
+ global: NullType;
2332
+ /** Return-to-go at each timestep */
2333
+ per_timestep: NullType;
2334
+ }>;
2335
+ }>;
1228
2336
  }>;
1229
2337
  readonly CellType: VariantType<{
1230
2338
  lstm: NullType;
@@ -1304,6 +2412,36 @@ export declare const Lightning: {
1304
2412
  /** Optional condition dimension for conditional generation */
1305
2413
  condition_dim: OptionType<IntegerType>;
1306
2414
  }>;
2415
+ /**
2416
+ * Decision Transformer: return-conditioned sequence generation.
2417
+ * Token layout: [R, s_0, a_0, s_1, a_1, ..., s_{T-1}, a_{T-1}]
2418
+ * Predicts actions conditioned on desired return and state history.
2419
+ */
2420
+ decision_transformer: StructType<{
2421
+ /** Sequence length (timesteps) */
2422
+ sequence_length: IntegerType;
2423
+ /** State dimension per timestep */
2424
+ state_dim: IntegerType;
2425
+ /** Action dimension per timestep */
2426
+ action_dim: IntegerType;
2427
+ /** Model dimension (transformer hidden size) */
2428
+ d_model: IntegerType;
2429
+ /** Number of attention heads */
2430
+ n_attention_heads: IntegerType;
2431
+ /** Number of transformer layers */
2432
+ n_layers: IntegerType;
2433
+ /** Feedforward dimension (default: 4 * d_model) */
2434
+ d_ff: OptionType<IntegerType>;
2435
+ /** Dropout rate */
2436
+ dropout: OptionType<FloatType>;
2437
+ /** Whether return is per-timestep or global */
2438
+ return_embedding: VariantType<{
2439
+ /** Single return value for entire sequence */
2440
+ global: NullType;
2441
+ /** Return-to-go at each timestep */
2442
+ per_timestep: NullType;
2443
+ }>;
2444
+ }>;
1307
2445
  }>;
1308
2446
  /** Output mode (determines loss function) */
1309
2447
  output: VariantType<{
@@ -1330,6 +2468,31 @@ export declare const Lightning: {
1330
2468
  /** Optional class weights matrix (n_heads, n_classes) */
1331
2469
  class_weights: OptionType<ArrayType<ArrayType<FloatType>>>;
1332
2470
  }>;
2471
+ /**
2472
+ * Mixed output types per head.
2473
+ * For Decision Transformer: combines binary (1 logit) and multiclass (n_classes logits) heads.
2474
+ * Binary heads: 1 logit → sigmoid → BCE loss
2475
+ * Multiclass heads: n_classes logits → softmax → CE loss
2476
+ * Action vectors use one-hot encoding for multiclass heads.
2477
+ */
2478
+ multi_head_mixed: StructType<{
2479
+ /** Array of head configurations */
2480
+ heads: ArrayType<StructType<{
2481
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
2482
+ head_type: VariantType<{
2483
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
2484
+ binary: NullType;
2485
+ /** Multi-class output: n_classes logits, softmax, CE loss */
2486
+ multiclass: StructType<{
2487
+ n_classes: IntegerType;
2488
+ }>;
2489
+ }>;
2490
+ /** Optional class weights for this head */
2491
+ class_weights: OptionType<ArrayType<FloatType>>;
2492
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
2493
+ conditional_on: OptionType<IntegerType>;
2494
+ }>>;
2495
+ }>;
1333
2496
  }>;
1334
2497
  /** Learning rate (default: 1e-3) */
1335
2498
  learning_rate: OptionType<FloatType>;
@@ -1403,6 +2566,61 @@ export declare const Lightning: {
1403
2566
  /** Group index per sample: [n_samples] */
1404
2567
  sample_groups: ArrayType<IntegerType>;
1405
2568
  }>;
2569
+ readonly GenerateConfigType: StructType<{
2570
+ /** Number of steps to generate */
2571
+ n_steps: IntegerType;
2572
+ /** Sampling temperature: 0.0 = argmax, > 0 = scaled sampling */
2573
+ temperature: FloatType;
2574
+ /** If true, return probabilities. If false, return samples. */
2575
+ return_probs: BooleanType;
2576
+ }>;
2577
+ readonly ReturnEmbeddingType: VariantType<{
2578
+ /** Single return value for entire sequence */
2579
+ global: NullType;
2580
+ /** Return-to-go at each timestep */
2581
+ per_timestep: NullType;
2582
+ }>;
2583
+ readonly HeadConfigType: StructType<{
2584
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
2585
+ head_type: VariantType<{
2586
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
2587
+ binary: NullType;
2588
+ /** Multi-class output: n_classes logits, softmax, CE loss */
2589
+ multiclass: StructType<{
2590
+ n_classes: IntegerType;
2591
+ }>;
2592
+ }>;
2593
+ /** Optional class weights for this head */
2594
+ class_weights: OptionType<ArrayType<FloatType>>;
2595
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
2596
+ conditional_on: OptionType<IntegerType>;
2597
+ }>;
2598
+ readonly TrajectoryGenerateConfigType: StructType<{
2599
+ /** Sampling temperature (0.0 = argmax, > 0 = stochastic) */
2600
+ temperature: FloatType;
2601
+ /** Whether to return probabilities or samples */
2602
+ return_probs: BooleanType;
2603
+ /** Optional constraint mask: (seq_len, action_dim) - FALSE disables action */
2604
+ action_constraints: OptionType<ArrayType<ArrayType<FloatType>>>;
2605
+ /** Optional temporal mask: (seq_len,) - FALSE marks invalid timesteps */
2606
+ temporal_mask: OptionType<ArrayType<FloatType>>;
2607
+ /** Optional head configs for multi_head_mixed output (enables proper multiclass sampling) */
2608
+ head_configs: OptionType<ArrayType<StructType<{
2609
+ /** Output type: binary (1 logit, sigmoid, BCE) or multiclass (n_classes logits, softmax, CE) */
2610
+ head_type: VariantType<{
2611
+ /** Single binary output: 1 logit, sigmoid, BCE loss */
2612
+ binary: NullType;
2613
+ /** Multi-class output: n_classes logits, softmax, CE loss */
2614
+ multiclass: StructType<{
2615
+ n_classes: IntegerType;
2616
+ }>;
2617
+ }>;
2618
+ /** Optional class weights for this head */
2619
+ class_weights: OptionType<ArrayType<FloatType>>;
2620
+ /** Optional: index of head this depends on (loss only computed when that head is 1) */
2621
+ conditional_on: OptionType<IntegerType>;
2622
+ }>>>;
2623
+ }>;
1406
2624
  };
1407
2625
  };
1408
2626
  //# sourceMappingURL=lightning.d.ts.map