1
0
forked from leto/LeMA

Added matrix multiplication

This commit is contained in:
Leto 2024-12-23 21:43:32 +01:00
parent 4f11d1e4fa
commit 2fe209fb2b
2 changed files with 30 additions and 7 deletions

View File

@ -6,16 +6,16 @@ int main()
{ {
srand(time(0)); srand(time(0));
Matrix a(3,3); Matrix a(2,1);
Matrix b(3,3); Matrix b(1,2);
a.Randomize(); a.Randomize();
b.Randomize(); b.Randomize();
a.Print("A"); a.Print("A");
b.Print("B"); b.Print("B");
a.Add(&b);
a.Print("A+B"); a.Multiply(&b).Print("A x B");
return 0; return 0;
} }

View File

@ -14,6 +14,7 @@ class Matrix{
void Set(float); void Set(float);
void Multiply(float); void Multiply(float);
Matrix Multiply(Matrix*);
void Add(float); void Add(float);
void Add(Matrix*); void Add(Matrix*);
@ -27,6 +28,7 @@ class Matrix{
// Constructors // Constructors
Matrix(int, int); Matrix(int, int);
Matrix(Matrix*);
}; };
// Constructs a zero matrix // Constructs a zero matrix
@ -40,6 +42,27 @@ Matrix::Matrix(int rows, int cols){
} }
} }
Matrix::Matrix(Matrix* other){
this->values = other->values;
}
// Multiply 2 matrices (AxB = this x other)
Matrix Matrix::Multiply(Matrix* other){
// Resulting size is this->M x other->N
Matrix result(this->values.size(), other->values[0].size());
for(int m = 0; m < result.values.size(); m++){
for(int n = 0; n < result.values[m].size(); n++){
// Sum multiplications
float buffer = 0.0F;
for(int i = 0; i < this->values[0].size(); i++){
buffer += this->values[m][i] * other->values[i][n];
}
result.values[m][n] = buffer;
}
}
return result;
}
// Add 2 matrices // Add 2 matrices
void Matrix::Add(Matrix* other){ void Matrix::Add(Matrix* other){
// Matrices need to be the same size // Matrices need to be the same size
@ -118,11 +141,11 @@ void Matrix::Transpose(){
// Invert matrix size // Invert matrix size
for(int m = 0; m < buffer[0].size(); m++){ for(int m = 0; m < buffer[0].size(); m++){
std::vector<float> line = {}; std::vector<float> row = {};
for(int n = 0; n < buffer.size(); n++){ for(int n = 0; n < buffer.size(); n++){
line.push_back(buffer[n][m]); row.push_back(buffer[n][m]);
} }
this->values.push_back(line); this->values.push_back(row);
} }
} }