[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent 3b8848430c
commit 4f573e0d8d
@@ -1,7 +1,7 @@
"""
- - - - - -- - - - - - - - - - - - - - - - - - - - - - -
Name - - sliding_window_attention.py
Goal - - Implement a neural network architecture using sliding
         window attention for sequence modeling tasks.
Detail: Total 5 layers neural network
    * Input layer
@@ -12,11 +12,11 @@ Author: Stephen Lee
Github: 245885195@qq.com
Date: 2024.10.20
References:
    1. Sutskever, I., et al. (2013). "On the Importance of
       Initialization and Momentum in Deep Learning." *Proceedings
       of the 30th International Conference on Machine Learning*.
    2. Katharopoulos, A., et al. (2020). "Transformers are RNNs: Fast
       Autoregressive Transformers with Linear Attention."
       *arXiv preprint arXiv:2006.16236*.
    3. [Attention Mechanisms in Neural Networks](https://en.wikipedia.org/wiki/Attention_(machine_learning))
- - - - - -- - - - - - - - - - - - - - - - - - - - - - -
@@ -28,7 +28,7 @@ import numpy as np
class SlidingWindowAttention:
    """Sliding Window Attention Module.

    This class implements a sliding window attention mechanism where
    the model attends to a fixed-size window of context around each token.

    Attributes:
@@ -54,13 +54,13 @@ class SlidingWindowAttention:
        Forward pass for the sliding window attention.

        Args:
            input_tensor (np.ndarray): Input tensor of shape (batch_size,
                seq_length, embed_dim).

        Returns:
            np.ndarray: Output tensor of shape (batch_size, seq_length, embed_dim).

        >>> x = np.random.randn(2, 10, 4)  # Batch size 2, sequence
        >>> attention = SlidingWindowAttention(embed_dim=4, window_size=3)
        >>> output = attention.forward(x)
        >>> output.shape
@@ -95,7 +95,7 @@ if __name__ == "__main__":

    # usage
    rng = np.random.default_rng()
    x = rng.standard_normal((2, 10, 4))  # Batch size 2,
    attention = SlidingWindowAttention(embed_dim=4, window_size=3)
    output = attention.forward(x)
    print(output)
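np.random.default_rng() is NumPy's recommended Generator API. For a reproducible run of the demo above, the generator can be seeded; the fixed seed below is an optional illustration and not part of the file:

import numpy as np

rng = np.random.default_rng(42)  # any fixed seed makes the demo deterministic
x = rng.standard_normal((2, 10, 4))  # same shape as in the example above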