Added layer backpropagation
This commit is contained in:
parent
9a1810775b
commit
fb49c794b2
33
layer.h
33
layer.h
@ -15,12 +15,20 @@ class Layer {
|
||||
Matrix activated_output;
|
||||
Matrix biases;
|
||||
|
||||
// Planning for back propagation
|
||||
// Each layer needs the derivative of Z with respect to W, derivative of A with respect to Z and derivative of loss with respect to A
|
||||
// Let's call them dzw, daz and dca
|
||||
Matrix daz;
|
||||
|
||||
static inline float Sigmoid(float);
|
||||
static inline float SigmoidPrime(float);
|
||||
|
||||
inline void Forward(); // Forward Pass with sigmoid
|
||||
inline void Forward(float (*activation)(float)); // Forward Pass with custom activation function
|
||||
|
||||
inline void BackPropagate(Matrix);
|
||||
inline void BackPropagate(Matrix, Matrix, float (*activation)(float)); // To backpropagate, we need the derivative of loss with respect to A and the derivative of used activation function
|
||||
|
||||
inline void Feed(Matrix);
|
||||
|
||||
// Constructors
|
||||
@ -29,6 +37,31 @@ class Layer {
|
||||
Layer();
|
||||
};
|
||||
|
||||
void Layer::BackPropagate(Matrix dzw, Matrix dca, float (*derivative)(float)){
|
||||
// Calculate daz ; derivative of activation function
|
||||
this->daz = this->activated_output.Function(derivative);
|
||||
// this->daz.Print("daz");
|
||||
|
||||
// We need to transpose dzw and extend down
|
||||
// dzw.Print("dzw");
|
||||
dzw = dzw.Transpose().ExtendDown(dca.values.size());
|
||||
// dzw.Print("dzw extended transposed");
|
||||
|
||||
Matrix dcw = this->daz.Hadamard(&dca).ExtendRight(this->input.values.size());
|
||||
// dcw.Print("daz . dca");
|
||||
dcw = dcw.Hadamard(&dzw);
|
||||
// dcw.Print("daz . dca . dzw : DCW");
|
||||
|
||||
// this->weights.Print("weights");
|
||||
|
||||
// Apply dcw to weights
|
||||
float learning_rate = 0.1F;
|
||||
Matrix reduced_dcw = dcw.Multiply(learning_rate);
|
||||
// We SUBSTRACT the derivative of loss with respect to the weights.
|
||||
this->weights = this->weights.Substract(&reduced_dcw);
|
||||
// this->weights.Print("New weights");
|
||||
}
|
||||
|
||||
Layer::Layer(){
|
||||
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user