Added layer backpropagation

parent 9a1810775b
commit fb49c794b2

layer.h (33 changes)
@@ -15,12 +15,20 @@ class Layer {
     Matrix activated_output;
     Matrix biases;
 
+    // Planning for back propagation
+    // Each layer needs the derivative of Z with respect to W, the derivative of A with respect to Z, and the derivative of the loss with respect to A
+    // Let's call them dzw, daz and dca
+    Matrix daz;
+
     static inline float Sigmoid(float);
     static inline float SigmoidPrime(float);
 
     inline void Forward(); // Forward pass with sigmoid
     inline void Forward(float (*activation)(float)); // Forward pass with a custom activation function
 
-    inline void BackPropagate(Matrix);
+    inline void BackPropagate(Matrix, Matrix, float (*activation)(float)); // To backpropagate, we need the derivative of the loss with respect to A and the derivative of the activation function used
 
     inline void Feed(Matrix);
 
     // Constructors
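Note that the header only declares Sigmoid and SigmoidPrime; their bodies are not in this commit. Since BackPropagate (below) applies the derivative to activated_output, i.e. to A rather than Z, a definition consistent with that usage would take the activated value. A sketch under that assumption, not the repository's actual implementation:

// Sketch only: plausible definitions for the declared helpers (not shown in
// this commit). sigmoid'(z) = sigmoid(z) * (1 - sigmoid(z)), so when the
// argument is the already-activated value a, the derivative is a * (1 - a).
#include <cmath>

inline float Layer::Sigmoid(float z) {
    return 1.0f / (1.0f + std::exp(-z));
}

inline float Layer::SigmoidPrime(float a) {
    return a * (1.0f - a); // valid because BackPropagate passes A, not Z
}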
@@ -29,6 +37,31 @@ class Layer {
     Layer();
 };
 
+void Layer::BackPropagate(Matrix dzw, Matrix dca, float (*derivative)(float)){
+    // Calculate daz: the derivative of the activation function
+    this->daz = this->activated_output.Function(derivative);
+    // this->daz.Print("daz");
+
+    // We need to transpose dzw and extend it down
+    // dzw.Print("dzw");
+    dzw = dzw.Transpose().ExtendDown(dca.values.size());
+    // dzw.Print("dzw extended transposed");
+
+    Matrix dcw = this->daz.Hadamard(&dca).ExtendRight(this->input.values.size());
+    // dcw.Print("daz . dca");
+    dcw = dcw.Hadamard(&dzw);
+    // dcw.Print("daz . dca . dzw : DCW");
+
+    // this->weights.Print("weights");
+
+    // Apply dcw to the weights
+    float learning_rate = 0.1F;
+    Matrix reduced_dcw = dcw.Multiply(learning_rate);
+    // We SUBTRACT the derivative of the loss with respect to the weights.
+    this->weights = this->weights.Substract(&reduced_dcw);
+    // this->weights.Print("New weights");
+}
+
 Layer::Layer(){
 
 }
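For reference, the arithmetic in BackPropagate is the chain rule followed by one gradient-descent step: dC/dW = (dZ/dW) * (dA/dZ) * (dC/dA) = dzw * daz * dca, then W <- W - 0.1 * dC/dW. A self-contained scalar sketch of the same computation (values and the squared-error loss are assumptions for illustration; the real code does this element-wise on Matrix objects):

// Scalar walkthrough of the update performed in Layer::BackPropagate
// (illustration only; the input, weight, and loss are made up).
#include <cmath>
#include <cstdio>

int main() {
    float x = 0.5f, w = 0.8f, b = 0.1f, target = 1.0f;

    // Forward pass: Z = W*x + b, A = sigmoid(Z)
    float z = w * x + b;
    float a = 1.0f / (1.0f + std::exp(-z));

    // The three factors named in the commit comments:
    float dzw = x;                    // dZ/dW: Z is linear in W, so this is the input
    float daz = a * (1.0f - a);       // dA/dZ: sigmoid derivative from the activated output
    float dca = 2.0f * (a - target);  // dC/dA, assuming squared error C = (A - target)^2

    // Chain rule, then the same subtraction the diff applies with Substract():
    float dcw = dzw * daz * dca;      // dC/dW
    float learning_rate = 0.1f;
    w -= learning_rate * dcw;         // W <- W - lr * dC/dW

    std::printf("dC/dW = %f, new w = %f\n", dcw, w);
    return 0;
}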
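Finally, a hypothetical call site for the new three-argument signature. Layer, Matrix, Feed, Forward, and SigmoidPrime are from this repository; MakeInput and ComputeLossGradient are invented helpers standing in for code not shown in the commit:

// Hypothetical usage (not part of this commit). dzw is the layer's own input,
// dca comes from the loss, and SigmoidPrime matches the sigmoid Forward() pass.
Layer layer;                              // assume weights and biases are initialized elsewhere
Matrix input = MakeInput();               // made-up helper for this sketch
layer.Feed(input);
layer.Forward();                          // sigmoid forward pass

Matrix dca = ComputeLossGradient(layer);  // made-up helper: dC/dA from the chosen loss
layer.BackPropagate(input, dca, Layer::SigmoidPrime); // assumes SigmoidPrime is accessible here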