mirror of
https://github.com/TheAlgorithms/Python.git
synced 2024-11-24 13:31:07 +00:00
Merge pull request #115 from RiptideBo/stephen_branch
add neuralnetwork_bp3.py
This commit is contained in:
commit
a38e684a73
|
@ -9,7 +9,7 @@ BP neural network with three layers
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
class Bpnw():
|
class Bpnn():
|
||||||
|
|
||||||
def __init__(self,n_layer1,n_layer2,n_layer3,rate_w=0.3,rate_t=0.3):
|
def __init__(self,n_layer1,n_layer2,n_layer3,rate_w=0.3,rate_t=0.3):
|
||||||
'''
|
'''
|
||||||
|
@ -38,7 +38,7 @@ class Bpnw():
|
||||||
def do_round(self,x):
|
def do_round(self,x):
|
||||||
return round(x, 3)
|
return round(x, 3)
|
||||||
|
|
||||||
def trian(self,patterns,data_train, data_teach, n_repeat, error_accuracy,draw_e = bool):
|
def trian(self,patterns,data_train, data_teach, n_repeat, error_accuracy, draw_e=False):
|
||||||
'''
|
'''
|
||||||
:param patterns: the number of patterns
|
:param patterns: the number of patterns
|
||||||
:param data_train: training data x; numpy.ndarray
|
:param data_train: training data x; numpy.ndarray
|
||||||
|
@ -49,9 +49,9 @@ class Bpnw():
|
||||||
'''
|
'''
|
||||||
data_train = np.asarray(data_train)
|
data_train = np.asarray(data_train)
|
||||||
data_teach = np.asarray(data_teach)
|
data_teach = np.asarray(data_teach)
|
||||||
print('-------------------Start Training-------------------------')
|
# print('-------------------Start Training-------------------------')
|
||||||
print(' - - Shape: Train_Data ',np.shape(data_train))
|
# print(' - - Shape: Train_Data ',np.shape(data_train))
|
||||||
print(' - - Shape: Teach_Data ',np.shape(data_teach))
|
# print(' - - Shape: Teach_Data ',np.shape(data_teach))
|
||||||
rp = 0
|
rp = 0
|
||||||
all_mse = []
|
all_mse = []
|
||||||
mse = 10000
|
mse = 10000
|
||||||
|
@ -95,9 +95,9 @@ class Bpnw():
|
||||||
plt.ylabel('All_mse')
|
plt.ylabel('All_mse')
|
||||||
plt.grid(True,alpha = 0.7)
|
plt.grid(True,alpha = 0.7)
|
||||||
plt.show()
|
plt.show()
|
||||||
print('------------------Training Complished---------------------')
|
# print('------------------Training Complished---------------------')
|
||||||
print(' - - Training epoch: ', rp, ' - - Mse: %.6f'%mse)
|
# print(' - - Training epoch: ', rp, ' - - Mse: %.6f'%mse)
|
||||||
print(' - - Last Output: ', final_out3)
|
# print(' - - Last Output: ', final_out3)
|
||||||
if draw_e:
|
if draw_e:
|
||||||
draw_error()
|
draw_error()
|
||||||
|
|
||||||
|
@ -108,9 +108,9 @@ class Bpnw():
|
||||||
'''
|
'''
|
||||||
data_test = np.asarray(data_test)
|
data_test = np.asarray(data_test)
|
||||||
produce_out = []
|
produce_out = []
|
||||||
print('-------------------Start Testing-------------------------')
|
# print('-------------------Start Testing-------------------------')
|
||||||
print(' - - Shape: Test_Data ',np.shape(data_test))
|
# print(' - - Shape: Test_Data ',np.shape(data_test))
|
||||||
print(np.shape(data_test))
|
# print(np.shape(data_test))
|
||||||
for g in range(np.shape(data_test)[0]):
|
for g in range(np.shape(data_test)[0]):
|
||||||
|
|
||||||
net_i = data_test[g]
|
net_i = data_test[g]
|
||||||
|
@ -127,8 +127,26 @@ class Bpnw():
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
#I will fish the mian function later
|
#example data
|
||||||
pass
|
data_x = [[1,2,3,4],
|
||||||
|
[5,6,7,8],
|
||||||
|
[2,2,3,4],
|
||||||
|
[7,7,8,8]]
|
||||||
|
data_y = [[1,0,0,0],
|
||||||
|
[0,1,0,0],
|
||||||
|
[0,0,1,0],
|
||||||
|
[0,0,0,1]]
|
||||||
|
|
||||||
|
test_x = [[1,2,3,4],
|
||||||
|
[3,2,3,4]]
|
||||||
|
|
||||||
|
#building network model
|
||||||
|
model = Bpnn(4,10,4)
|
||||||
|
#training the model
|
||||||
|
model.trian(patterns=4,data_train=data_x,data_teach=data_y,
|
||||||
|
n_repeat=100,error_accuracy=0.1,draw_e=True)
|
||||||
|
#predicting data
|
||||||
|
model.predict(test_x)
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user