From fb49c794b2ae3e31a20b9eb29ae5501672799573 Mon Sep 17 00:00:00 2001 From: LeLeLeLeto Date: Tue, 31 Dec 2024 17:21:57 +0100 Subject: [PATCH] Added layer backpropagation --- layer.h | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/layer.h b/layer.h index 529cc71..8294d13 100644 --- a/layer.h +++ b/layer.h @@ -15,12 +15,20 @@ class Layer { Matrix activated_output; Matrix biases; + // Planning for back propagation + // Each layer needs the derivative of Z with respect to W, derivative of A with respect to Z and derivative of loss with respect to A + // Let's call them dzw, daz and dca + Matrix daz; + static inline float Sigmoid(float); static inline float SigmoidPrime(float); inline void Forward(); // Forward Pass with sigmoid inline void Forward(float (*activation)(float)); // Forward Pass with custom activation function + inline void BackPropagate(Matrix); + inline void BackPropagate(Matrix, Matrix, float (*activation)(float)); // To backpropagate, we need the derivative of loss with respect to A and the derivative of the used activation function + inline void Feed(Matrix); // Constructors @@ -29,6 +37,31 @@ class Layer { Layer(); }; +void Layer::BackPropagate(Matrix dzw, Matrix dca, float (*derivative)(float)){ + // Calculate daz ; derivative of activation function + this->daz = this->activated_output.Function(derivative); + // this->daz.Print("daz"); + + // We need to transpose dzw and extend down + // dzw.Print("dzw"); + dzw = dzw.Transpose().ExtendDown(dca.values.size()); + // dzw.Print("dzw extended transposed"); + + Matrix dcw = this->daz.Hadamard(&dca).ExtendRight(this->input.values.size()); + // dcw.Print("daz . dca"); + dcw = dcw.Hadamard(&dzw); + // dcw.Print("daz . dca . dzw : DCW"); + + // this->weights.Print("weights"); + + // Apply dcw to weights + float learning_rate = 0.1F; + Matrix reduced_dcw = dcw.Multiply(learning_rate); + // We SUBTRACT the derivative of loss with respect to the weights. 
+ this->weights = this->weights.Substract(&reduced_dcw); + // this->weights.Print("New weights"); +} + Layer::Layer(){ }