Compare commits
No commits in common. "a78b3ef56911e57cd75c531e5ca15b36188563e9" and "9a1810775b30fd5f9baf6e935c9b351264e92395" have entirely different histories.
a78b3ef569
...
9a1810775b
2
.gitignore
vendored
2
.gitignore
vendored
@ -12,5 +12,3 @@ main.exe
|
|||||||
network.exe
|
network.exe
|
||||||
build/Debug/main.o
|
build/Debug/main.o
|
||||||
build/Debug/outDebug.exe
|
build/Debug/outDebug.exe
|
||||||
matrices.exe
|
|
||||||
tempCodeRunnerFile.cpp
|
|
||||||
|
33
layer.h
33
layer.h
@ -15,20 +15,12 @@ class Layer {
|
|||||||
Matrix activated_output;
|
Matrix activated_output;
|
||||||
Matrix biases;
|
Matrix biases;
|
||||||
|
|
||||||
// Planning for back propagation
|
|
||||||
// Each layer needs the derivative of Z with respect to W, derivative of A with respect to Z and derivative of loss with respect to A
|
|
||||||
// Let's call them dzw, daz and dca
|
|
||||||
Matrix daz;
|
|
||||||
|
|
||||||
static inline float Sigmoid(float);
|
static inline float Sigmoid(float);
|
||||||
static inline float SigmoidPrime(float);
|
static inline float SigmoidPrime(float);
|
||||||
|
|
||||||
inline void Forward(); // Forward Pass with sigmoid
|
inline void Forward(); // Forward Pass with sigmoid
|
||||||
inline void Forward(float (*activation)(float)); // Forward Pass with custom activation function
|
inline void Forward(float (*activation)(float)); // Forward Pass with custom activation function
|
||||||
|
|
||||||
inline void BackPropagate(Matrix);
|
|
||||||
inline void BackPropagate(Matrix, Matrix, float (*activation)(float)); // To backpropagate, we need the derivative of loss with respect to A and the derivative of used activation function
|
|
||||||
|
|
||||||
inline void Feed(Matrix);
|
inline void Feed(Matrix);
|
||||||
|
|
||||||
// Constructors
|
// Constructors
|
||||||
@ -37,31 +29,6 @@ class Layer {
|
|||||||
Layer();
|
Layer();
|
||||||
};
|
};
|
||||||
|
|
||||||
void Layer::BackPropagate(Matrix dzw, Matrix dca, float (*derivative)(float)){
|
|
||||||
// Calculate daz ; derivative of activation function
|
|
||||||
this->daz = this->activated_output.Function(derivative);
|
|
||||||
// this->daz.Print("daz");
|
|
||||||
|
|
||||||
// We need to transpose dzw and extend down
|
|
||||||
// dzw.Print("dzw");
|
|
||||||
dzw = dzw.Transpose().ExtendDown(dca.values.size());
|
|
||||||
// dzw.Print("dzw extended transposed");
|
|
||||||
|
|
||||||
Matrix dcw = this->daz.Hadamard(&dca).ExtendRight(this->input.values.size());
|
|
||||||
// dcw.Print("daz . dca");
|
|
||||||
dcw = dcw.Hadamard(&dzw);
|
|
||||||
// dcw.Print("daz . dca . dzw : DCW");
|
|
||||||
|
|
||||||
// this->weights.Print("weights");
|
|
||||||
|
|
||||||
// Apply dcw to weights
|
|
||||||
float learning_rate = 0.1F;
|
|
||||||
Matrix reduced_dcw = dcw.Multiply(learning_rate);
|
|
||||||
// We SUBSTRACT the derivative of loss with respect to the weights.
|
|
||||||
this->weights = this->weights.Substract(&reduced_dcw);
|
|
||||||
// this->weights.Print("New weights");
|
|
||||||
}
|
|
||||||
|
|
||||||
Layer::Layer(){
|
Layer::Layer(){
|
||||||
|
|
||||||
}
|
}
|
||||||
|
40
matrices.h
40
matrices.h
@ -9,8 +9,9 @@
|
|||||||
#define assertm(exp, msg) assert((void(msg), exp))
|
#define assertm(exp, msg) assert((void(msg), exp))
|
||||||
|
|
||||||
class Matrix{
|
class Matrix{
|
||||||
public:
|
private:
|
||||||
std::vector<std::vector<float>> values;
|
std::vector<std::vector<float>> values;
|
||||||
|
public:
|
||||||
inline void Randomize();
|
inline void Randomize();
|
||||||
inline void Randomize(float, float);
|
inline void Randomize(float, float);
|
||||||
|
|
||||||
@ -21,7 +22,7 @@ class Matrix{
|
|||||||
inline Matrix Multiply(float);
|
inline Matrix Multiply(float);
|
||||||
inline Matrix Multiply(const Matrix*);
|
inline Matrix Multiply(const Matrix*);
|
||||||
|
|
||||||
inline Matrix Hadamard(const Matrix*);
|
inline void Hadamard(const Matrix*);
|
||||||
|
|
||||||
inline Matrix Add(float);
|
inline Matrix Add(float);
|
||||||
inline Matrix Add(const Matrix*);
|
inline Matrix Add(const Matrix*);
|
||||||
@ -31,9 +32,6 @@ class Matrix{
|
|||||||
|
|
||||||
inline Matrix Function(float (*f)(float));
|
inline Matrix Function(float (*f)(float));
|
||||||
|
|
||||||
inline Matrix ExtendRight(int);
|
|
||||||
inline Matrix ExtendDown(int);
|
|
||||||
|
|
||||||
inline void Print(std::string_view);
|
inline void Print(std::string_view);
|
||||||
|
|
||||||
inline Matrix Transpose();
|
inline Matrix Transpose();
|
||||||
@ -58,28 +56,6 @@ Matrix::Matrix(){
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Matrix Matrix::ExtendRight(int new_size){
|
|
||||||
// Extend the matrix to the right
|
|
||||||
Matrix result(this->values.size(), new_size);
|
|
||||||
for(int n = 0; n < result.values.size(); n++){
|
|
||||||
for(int m = 0; m < result.values[n].size(); m++){
|
|
||||||
result.values[n][m] = this->values[n][0];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
Matrix Matrix::ExtendDown(int new_size){
|
|
||||||
// Extend the matrix down
|
|
||||||
Matrix result(new_size, this->values[0].size());
|
|
||||||
for(int n = 0; n < result.values.size(); n++){
|
|
||||||
for(int m = 0; m < result.values[n].size(); m++){
|
|
||||||
result.values[n][m] = this->values[0][m];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return result;
|
|
||||||
}
|
|
||||||
|
|
||||||
Matrix Matrix::operator=(const Matrix* other){
|
Matrix Matrix::operator=(const Matrix* other){
|
||||||
return this->Swap(other);
|
return this->Swap(other);
|
||||||
}
|
}
|
||||||
@ -139,18 +115,16 @@ Matrix Matrix::Swap(const Matrix* other){
|
|||||||
return *this;
|
return *this;
|
||||||
}
|
}
|
||||||
|
|
||||||
Matrix Matrix::Hadamard(const Matrix* other){
|
void Matrix::Hadamard(const Matrix* other){
|
||||||
// Matrices need to be the same size
|
// Matrices need to be the same size
|
||||||
assertm(this->values.size() == other->values.size() &&
|
assertm(this->values.size() == other->values.size() &&
|
||||||
this->values[0].size() == other->values[0].size(),
|
this->values[0].size() == other->values[0].size(),
|
||||||
"Matrices need to be the same size");
|
"Matrices need to be the same size");
|
||||||
Matrix result = this;
|
for(int m = 0; m < this->values.size(); m++){
|
||||||
for(int m = 0; m < result.values.size(); m++){
|
for(int n = 0; n < this->values[m].size(); n++){
|
||||||
for(int n = 0; n < result.values[m].size(); n++){
|
this->values[m][n] *= other->values[m][n];
|
||||||
result.values[m][n] = this->values[m][n] * other->values[m][n];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return result;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Multiply 2 matrices (AxB = this x other)
|
// Multiply 2 matrices (AxB = this x other)
|
||||||
|
16
network.h
16
network.h
@ -15,27 +15,11 @@ class Network {
|
|||||||
|
|
||||||
inline void Forward();
|
inline void Forward();
|
||||||
|
|
||||||
inline void BackPropagate(Matrix);
|
|
||||||
|
|
||||||
// Constructors
|
// Constructors
|
||||||
// Input size, Array of hidden sizes, Output size
|
// Input size, Array of hidden sizes, Output size
|
||||||
Network(int, std::vector<int>, int);
|
Network(int, std::vector<int>, int);
|
||||||
};
|
};
|
||||||
|
|
||||||
void Network::BackPropagate(Matrix target){
|
|
||||||
// Calculate derivative of loss in respect to A (dca) for output layer
|
|
||||||
// loss = (A - Y)^2
|
|
||||||
// derivative = 2(A - Y)
|
|
||||||
Matrix loss = this->output_layer.activated_output.Substract(&target);
|
|
||||||
loss = loss.Hadamard(&loss);
|
|
||||||
// loss.Print("Loss");
|
|
||||||
Matrix dca = this->output_layer.activated_output.Substract(&target);
|
|
||||||
dca = dca.Multiply(2.0F);
|
|
||||||
// dca.Print("DCA");
|
|
||||||
|
|
||||||
this->output_layer.BackPropagate(this->hidden_layers[this->hidden_layers.size() - 1].activated_output, dca, &Layer::SigmoidPrime);
|
|
||||||
}
|
|
||||||
|
|
||||||
Network::Network(int input_size, std::vector<int> hidden_sizes, int output_size){
|
Network::Network(int input_size, std::vector<int> hidden_sizes, int output_size){
|
||||||
this->input = Matrix(input_size, 1);
|
this->input = Matrix(input_size, 1);
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user