mirror of
https://github.com/TheAlgorithms/Python.git
synced 2024-11-30 16:31:08 +00:00
improve
This commit is contained in:
parent
e1befed976
commit
53b6fe15c9
|
@ -9,7 +9,7 @@ BP neural network with three layers
|
|||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
class Bpnw():
|
||||
class Bpnn():
|
||||
|
||||
def __init__(self,n_layer1,n_layer2,n_layer3,rate_w=0.3,rate_t=0.3):
|
||||
'''
|
||||
|
@ -38,7 +38,7 @@ class Bpnw():
|
|||
def do_round(self,x):
|
||||
return round(x, 3)
|
||||
|
||||
def trian(self,patterns,data_train, data_teach, n_repeat, error_accuracy,draw_e = bool):
|
||||
def trian(self,patterns,data_train, data_teach, n_repeat, error_accuracy, draw_e=False):
|
||||
'''
|
||||
:param patterns: the number of patterns
|
||||
:param data_train: training data x; numpy.ndarray
|
||||
|
@ -127,8 +127,26 @@ class Bpnw():
|
|||
|
||||
|
||||
def main():
|
||||
#I will fish the mian function later
|
||||
pass
|
||||
#example data
|
||||
data_x = [[1,2,3,4],
|
||||
[5,6,7,8],
|
||||
[2,2,3,4],
|
||||
[7,7,8,8]]
|
||||
data_y = [[1,0,0,0],
|
||||
[0,1,0,0],
|
||||
[0,0,1,0],
|
||||
[0,0,0,1]]
|
||||
|
||||
test_x = [[1,2,3,4],
|
||||
[3,2,3,4]]
|
||||
|
||||
#building network model
|
||||
model = Bpnn(4,10,4)
|
||||
#training the model
|
||||
model.trian(patterns=4,data_train=data_x,data_teach=data_y,
|
||||
n_repeat=100,error_accuracy=0.1,draw_e=True)
|
||||
#predicting data
|
||||
model.predict(test_x)
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
Loading…
Reference in New Issue
Block a user