Added layer backpropagation

This commit is contained in:
LeLeLeLeto 2024-12-31 17:21:57 +01:00
parent 9a1810775b
commit fb49c794b2

33
layer.h

@@ -15,12 +15,20 @@ class Layer {
Matrix activated_output;
Matrix biases;
// Planning for back propagation
// Each layer needs the derivative of Z with respect to W, the derivative of A with respect to Z, and the derivative of the loss with respect to A
// Let's call them dzw, daz and dca
Matrix daz;
static inline float Sigmoid(float);
static inline float SigmoidPrime(float);
inline void Forward(); // Forward Pass with sigmoid
inline void Forward(float (*activation)(float)); // Forward Pass with custom activation function
inline void BackPropagate(Matrix);
inline void BackPropagate(Matrix, Matrix, float (*activation)(float)); // To backpropagate, we need dzw, the derivative of the loss with respect to A, and the derivative of the activation function used
inline void Feed(Matrix);
// Constructors
@@ -29,6 +37,31 @@ class Layer {
Layer();
};
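The three derivatives named in the comments above are the factors of the chain rule for this layer's weights. A minimal worked form, assuming the forward pass computes Z = W x + b and A = sigma(Z) (so dzw is the layer input and daz is the activation derivative):

\frac{\partial C}{\partial w_{ij}}
  = \underbrace{\frac{\partial C}{\partial a_i}}_{\text{dca}_i}
    \cdot \underbrace{\sigma'(z_i)}_{\text{daz}_i}
    \cdot \underbrace{x_j}_{\text{dzw}_j}
\qquad\Longleftrightarrow\qquad
\frac{\partial C}{\partial W} = (\text{dca} \odot \text{daz})\, x^{\top}

The Transpose, ExtendDown and ExtendRight calls in the implementation below build exactly this outer product out of element-wise (Hadamard) multiplications.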
void Layer::BackPropagate(Matrix dzw, Matrix dca, float (*derivative)(float)){
// Calculate daz: the derivative of the activation function, evaluated on the activated output
this->daz = this->activated_output.Function(derivative);
// this->daz.Print("daz");
// We need to transpose dzw and extend it down to match dca's size
// dzw.Print("dzw");
dzw = dzw.Transpose().ExtendDown(dca.values.size());
// dzw.Print("dzw extended transposed");
Matrix dcw = this->daz.Hadamard(&dca).ExtendRight(this->input.values.size());
// dcw.Print("daz . dca");
dcw = dcw.Hadamard(&dzw);
// dcw.Print("daz . dca . dzw : DCW");
// this->weights.Print("weights");
// Apply dcw to weights
float learning_rate = 0.1F;
Matrix reduced_dcw = dcw.Multiply(learning_rate);
// We SUBTRACT the derivative of the loss with respect to the weights.
this->weights = this->weights.Substract(&reduced_dcw);
// this->weights.Print("New weights");
}
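For reference, a self-contained sketch of the same per-weight update written with plain loops and hypothetical sizes and values (independent of the repository's Matrix class), showing that the transposes, extensions and Hadamard products above reduce to dca_i * daz_i * x_j followed by a gradient-descent step:

#include <cstdio>
#include <vector>

int main() {
    // Hypothetical layer: 2 outputs, 3 inputs (values chosen only for illustration)
    std::vector<float> x   = {0.5f, -1.0f, 2.0f};   // layer input  -> dzw
    std::vector<float> a   = {0.7f, 0.3f};          // activated outputs A = sigmoid(Z)
    std::vector<float> dca = {0.2f, -0.4f};         // dC/dA coming from the loss (or the next layer)
    std::vector<std::vector<float>> w = {{0.1f, 0.2f, 0.3f},
                                         {0.4f, 0.5f, 0.6f}};
    const float learning_rate = 0.1f;

    for (std::size_t i = 0; i < w.size(); ++i) {
        const float daz = a[i] * (1.0f - a[i]);      // sigmoid'(z) expressed through a = sigmoid(z)
        for (std::size_t j = 0; j < w[i].size(); ++j) {
            const float dcw = dca[i] * daz * x[j];   // chain rule: dC/dA * dA/dZ * dZ/dW
            w[i][j] -= learning_rate * dcw;          // subtract the scaled gradient (the Substract call)
        }
    }

    for (std::size_t i = 0; i < w.size(); ++i)
        for (std::size_t j = 0; j < w[i].size(); ++j)
            std::printf("w[%zu][%zu] = %f\n", i, j, w[i][j]);
    return 0;
}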
Layer::Layer(){
}