rxnn 0.1.5__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/rxt/models.py +7 -7
- rxnn/training/base.py +1 -1
- rxnn/training/bml.py +2 -2
- rxnn/training/callbacks.py +1 -2
- rxnn/transformers/attention.py +1 -1
- rxnn/transformers/layers.py +3 -3
- rxnn/transformers/models.py +3 -3
- {rxnn-0.1.5.dist-info → rxnn-0.1.6.dist-info}/METADATA +1 -1
- {rxnn-0.1.5.dist-info → rxnn-0.1.6.dist-info}/RECORD +11 -11
- {rxnn-0.1.5.dist-info → rxnn-0.1.6.dist-info}/LICENSE +0 -0
- {rxnn-0.1.5.dist-info → rxnn-0.1.6.dist-info}/WHEEL +0 -0
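Every substantive change in this release is the same one-line pattern repeated across modules: bare absolute imports such as `from callbacks import TrainerCallback` become explicit relative imports such as `from .callbacks import TrainerCallback`. In Python 3 a bare name is always resolved as a top-level module, so the 0.1.5 form only worked while the importing file's own directory happened to be on `sys.path`; once the wheel is installed as the `rxnn` package, those imports raise `ModuleNotFoundError`. A minimal, self-contained sketch of the failure mode (the `demo` package layout is invented here to mirror `rxnn/training`):

```python
# Sketch of the bug fixed in 0.1.6, reproduced with a throwaway package
# ("demo" is a hypothetical stand-in for the installed rxnn package).
import pathlib, subprocess, sys, tempfile

root = pathlib.Path(tempfile.mkdtemp())
pkg = root / "demo" / "training"
pkg.mkdir(parents=True)
(root / "demo" / "__init__.py").write_text("")
(pkg / "__init__.py").write_text("")
(pkg / "callbacks.py").write_text("class TrainerCallback: pass\n")

# 0.1.5 style: a bare absolute import of a sibling module.
(pkg / "base.py").write_text("from callbacks import TrainerCallback\n")
r = subprocess.run([sys.executable, "-c", "import demo.training.base"],
                   cwd=root, capture_output=True, text=True)
print(r.stderr.strip().splitlines()[-1])
# -> ModuleNotFoundError: No module named 'callbacks'

# 0.1.6 style: an explicit relative import of the same sibling module.
(pkg / "base.py").write_text("from .callbacks import TrainerCallback\n")
r = subprocess.run([sys.executable, "-c", "import demo.training.base; print('ok')"],
                   cwd=root, capture_output=True, text=True)
print(r.stdout.strip())  # -> ok
```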
rxnn/rxt/models.py
CHANGED
@@ -2,13 +2,13 @@ import torch
 from torch import nn
 from typing import TypedDict, Union
 from huggingface_hub import PyTorchModelHubMixin
-from …
-from …
-from …
-from …
-from …
-from …
-from …
+from ..transformers.positional import RotaryPositionalEmbedding
+from ..transformers.attention import init_attention
+from ..transformers.layers import ReactiveTransformerLayer
+from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEncoder, ReactiveTransformerDecoder
+from ..transformers.ff import get_activation_layer
+from ..memory.stm import ShortTermMemory
+from ..utils import get_model_size
 
 
 class RxTAlphaComponentConfig(TypedDict):
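In `rxnn/rxt/models.py` the replacements use two leading dots, which climb from the module's own package (`rxnn.rxt`) up to `rxnn` before descending into `transformers`, `memory`, and `utils`. The resolution can be checked with the standard library alone; a small sketch (nothing here imports rxnn itself):

```python
# Sketch: how the relative names in the diff resolve, via the stdlib.
# resolve_name() performs the same dot arithmetic the import system uses.
from importlib.util import resolve_name

# rxnn/rxt/models.py lives in package "rxnn.rxt"; ".." climbs to "rxnn".
print(resolve_name("..transformers.attention", package="rxnn.rxt"))
# -> rxnn.transformers.attention
print(resolve_name("..memory.stm", package="rxnn.rxt"))
# -> rxnn.memory.stm

# rxnn/training/base.py lives in "rxnn.training"; "." stays in place.
print(resolve_name(".callbacks", package="rxnn.training"))
# -> rxnn.training.callbacks
```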
rxnn/training/base.py
CHANGED
@@ -7,7 +7,7 @@ from torch.utils.tensorboard import SummaryWriter
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel
 from typing import Callable
-from callbacks import TrainerCallback
+from .callbacks import TrainerCallback
 
 
 class BaseTrainer(ABC):
rxnn/training/bml.py
CHANGED
@@ -5,8 +5,8 @@ import math
 from huggingface_hub import PyTorchModelHubMixin
 from typing import Union
 import torch.distributed as dist
-from …
-from …
+from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
+from ..training.base import BaseTrainer
 
 class MLMHead(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
     def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
rxnn/training/callbacks.py
CHANGED
@@ -3,10 +3,9 @@ import numpy as np
 import torch
 import torch.nn as nn
 from typing import Union
-from rxnn.utils import human_format
 from torch.nn.parallel import DistributedDataParallel
 from huggingface_hub import PyTorchModelHubMixin
-
+from ..utils import human_format
 
 class TrainerCallback:
     def on_epoch_start(self, model: torch.nn.Module, epoch: int) -> None:
rxnn/transformers/attention.py
CHANGED
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import math
-from positional import RotaryPositionalEmbedding, RelativePositionalEmbedding
+from .positional import RotaryPositionalEmbedding, RelativePositionalEmbedding
 
 
 class MultiHeadAttention(nn.Module):
rxnn/transformers/layers.py
CHANGED
@@ -1,8 +1,8 @@
 import torch
 import torch.nn as nn
-from attention import MultiHeadAttention
-from ff import FeedForward, GatedFeedForward
-from moe import MoeFeedForward, GatedMoeFeedForward
+from .attention import MultiHeadAttention
+from .ff import FeedForward, GatedFeedForward
+from .moe import MoeFeedForward, GatedMoeFeedForward
 
 
 class ReactiveTransformerLayer(nn.Module):
rxnn/transformers/models.py
CHANGED
@@ -1,8 +1,8 @@
 import torch
 import torch.nn as nn
-from positional import AbsolutePositionalEmbedding
-from mask import create_causal_mask
-from …
+from .positional import AbsolutePositionalEmbedding
+from .mask import create_causal_mask
+from ..memory.stm import ShortTermMemory
 
 
 class ReactiveTransformerBase(nn.Module):
{rxnn-0.1.5.dist-info → rxnn-0.1.6.dist-info}/RECORD
CHANGED
@@ -5,25 +5,25 @@ rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
 rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=…
+rxnn/rxt/models.py,sha256=INTFeNcqzAsjyWhNtbBHL4Tx7tYDsaQHgm72tf6u20M,6918
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/training/base.py,sha256=…
-rxnn/training/bml.py,sha256=…
-rxnn/training/callbacks.py,sha256=…
+rxnn/training/base.py,sha256=EvR1qKE9O2BLJ5IkFxKMOdQPJoZ20j6FEZR1KbbY-vg,11031
+rxnn/training/bml.py,sha256=pyK6aRLpXlPuLge6CQ9PD64Un57yUgbOpu8lUfTdV9k,14575
+rxnn/training/callbacks.py,sha256=IyVJAJ0ggJmfIWBZnpzV9U08URYCeWIStK_wbx7m3pg,21090
 rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
 rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=…
+rxnn/transformers/attention.py,sha256=FfEYE0THO73p_1eRupr2mcwfW4UbI_riIxkHfr8X_1c,14022
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=…
+rxnn/transformers/layers.py,sha256=jdM7L0uOMO68aZiu9p6jba1Hx3aLGOChF1Zz-j4vJ5U,5364
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=…
+rxnn/transformers/models.py,sha256=sLYMkVOWQ1NcM1evpCTUMucXvklySpeNT0IqpIGKmyc,6716
 rxnn/transformers/moe.py,sha256=JQ5QSX4FS7S-fqB7-s1ZmJbPpOeD_Injn8o4vo7wGQE,4936
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=wSz_1wNloqtuiix5w2Mcsj5NhaO9QlY0j__TVG7wJnM,3938
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.…
-rxnn-0.1.…
-rxnn-0.1.…
-rxnn-0.1.…
+rxnn-0.1.6.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.6.dist-info/METADATA,sha256=q5Lxgo6vMhmObnUk6FQx8KU-MAKVd7l8s3a4M6Sarng,14486
+rxnn-0.1.6.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.6.dist-info/RECORD,,
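For reference, each RECORD row has the form `path,sha256=<digest>,<size>`, where the digest is the urlsafe-base64 SHA-256 of the file with `=` padding stripped and the size is in bytes, which is why every edited module above gets both a new hash and a new length. A sketch for recomputing an entry (`record_entry` is a hypothetical helper, not part of rxnn):

```python
# Sketch: recompute a wheel RECORD entry (path,sha256=digest,size).
# The digest is urlsafe base64 of the SHA-256, '=' padding stripped.
import base64
import hashlib
from pathlib import Path

def record_entry(path: str) -> str:
    data = Path(path).read_bytes()
    digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest())
    return f"{path},sha256={digest.rstrip(b'=').decode('ascii')},{len(data)}"

# e.g. record_entry("rxnn/utils.py") should reproduce the row above:
# rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
```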
{rxnn-0.1.5.dist-info → rxnn-0.1.6.dist-info}/LICENSE
File without changes
{rxnn-0.1.5.dist-info → rxnn-0.1.6.dist-info}/WHEEL
File without changes