mirror of
https://github.com/TheAlgorithms/Python.git
synced 2025-04-06 05:45:53 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
parent
3d9b893ee0
commit
5c186b16e8
@ -97,32 +97,22 @@ class LSTM:
|
||||
Initialize the weights and biases for the LSTM network.
|
||||
"""
|
||||
|
||||
self.wf = self.init_weights(
|
||||
self.char_size + self.hidden_dim, self.hidden_dim
|
||||
)
|
||||
self.wf = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
|
||||
self.bf = np.zeros((self.hidden_dim, 1))
|
||||
|
||||
self.wi = self.init_weights(
|
||||
self.char_size + self.hidden_dim, self.hidden_dim
|
||||
)
|
||||
self.wi = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
|
||||
self.bi = np.zeros((self.hidden_dim, 1))
|
||||
|
||||
self.wc = self.init_weights(
|
||||
self.char_size + self.hidden_dim, self.hidden_dim
|
||||
)
|
||||
self.wc = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
|
||||
self.bc = np.zeros((self.hidden_dim, 1))
|
||||
|
||||
self.wo = self.init_weights(
|
||||
self.char_size + self.hidden_dim, self.hidden_dim
|
||||
)
|
||||
self.wo = self.init_weights(self.char_size + self.hidden_dim, self.hidden_dim)
|
||||
self.bo = np.zeros((self.hidden_dim, 1))
|
||||
|
||||
self.wy = self.init_weights(self.hidden_dim, self.char_size)
|
||||
self.by = np.zeros((self.char_size, 1))
|
||||
|
||||
def init_weights(
|
||||
self, input_dim: int, output_dim: int
|
||||
) -> np.ndarray:
|
||||
def init_weights(self, input_dim: int, output_dim: int) -> np.ndarray:
|
||||
"""
|
||||
Initialize weights with random values.
|
||||
|
||||
@ -367,7 +357,6 @@ class LSTM:
|
||||
print(f"Accuracy: {round(accuracy * 100 / len(self.train_X), 2)}%")
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
data = """Long Short-Term Memory (LSTM) networks are a type
|
||||
of recurrent neural network (RNN) capable of learning "
|
||||
|
Loading…
x
Reference in New Issue
Block a user