1
0
forked from leto/LeMA

Adding now returns a Matrix object

This commit is contained in:
Leto 2024-12-24 14:05:08 +01:00
parent 6eb4b00c12
commit 78a8185e19
2 changed files with 26 additions and 19 deletions

View File

@ -1,4 +1,3 @@
#include<iostream>
#include "matrices.h" #include "matrices.h"
// using namespace std; // using namespace std;
@ -14,5 +13,7 @@ int main()
a.Print("A"); a.Print("A");
b.Print("B"); b.Print("B");
a.Add(&b).Print("A + B");
return 0; return 0;
} }

View File

@ -15,15 +15,15 @@ class Matrix{
inline void Set(float); inline void Set(float);
inline void Multiply(float); inline void Multiply(float);
inline Matrix Multiply(Matrix*); inline Matrix Multiply(const Matrix*);
inline void Hadamard(Matrix*); inline void Hadamard(const Matrix*);
inline void Add(float); inline Matrix Add(float);
inline void Add(Matrix*); inline Matrix Add(const Matrix*);
inline void Substract(float); inline void Substract(float);
inline void Substract(Matrix*); inline void Substract(const Matrix*);
inline void Print(std::string_view); inline void Print(std::string_view);
@ -32,10 +32,11 @@ class Matrix{
// --- Operators // --- Operators
// Assign // Assign
inline Matrix operator=(const Matrix*); inline Matrix operator=(const Matrix*);
inline Matrix operator+(const Matrix*);
// Constructors // Constructors
inline Matrix(int, int); inline Matrix(int, int);
inline Matrix(Matrix*); inline Matrix(const Matrix*);
}; };
Matrix Matrix::operator=(const Matrix* other){ Matrix Matrix::operator=(const Matrix* other){
@ -54,11 +55,11 @@ Matrix::Matrix(int rows, int cols){
} }
} }
Matrix::Matrix(Matrix* other){ Matrix::Matrix(const Matrix* other){
this->values = other->values; this->values = other->values;
} }
void Matrix::Hadamard(Matrix* other){ void Matrix::Hadamard(const Matrix* other){
// Matrices need to be the same size // Matrices need to be the same size
assertm(this->values.size() == other->values.size() && assertm(this->values.size() == other->values.size() &&
this->values[0].size() == other->values[0].size(), this->values[0].size() == other->values[0].size(),
@ -71,7 +72,7 @@ void Matrix::Hadamard(Matrix* other){
} }
// Multiply 2 matrices (AxB = this x other) // Multiply 2 matrices (AxB = this x other)
Matrix Matrix::Multiply(Matrix* other){ Matrix Matrix::Multiply(const Matrix* other){
// Matrices need to be of right size // Matrices need to be of right size
assertm(this->values[0].size() == other->values.size(),"Wrong matrix size"); assertm(this->values[0].size() == other->values.size(),"Wrong matrix size");
@ -91,20 +92,23 @@ Matrix Matrix::Multiply(Matrix* other){
} }
// Add 2 matrices // Add 2 matrices
void Matrix::Add(Matrix* other){ Matrix Matrix::Add(const Matrix* other){
// Matrices need to be the same size // Matrices need to be the same size
assertm(this->values.size() == other->values.size() && assertm(this->values.size() == other->values.size() &&
this->values[0].size() == other->values[0].size(), this->values[0].size() == other->values[0].size(),
"Wrong matrix size"); "Wrong matrix size");
for(int m = 0; m < this->values.size(); m++){
for(int n = 0; n < this->values[m].size(); n++){ Matrix result = this;
this->values[m][n] += other->values[m][n]; for(int m = 0; m < result.values.size(); m++){
for(int n = 0; n < result.values[m].size(); n++){
result.values[m][n] += other->values[m][n];
} }
} }
return result;
} }
// Substract 2 matrices // Substract 2 matrices
void Matrix::Substract(Matrix* other){ void Matrix::Substract(const Matrix* other){
// Matrices need to be the same size // Matrices need to be the same size
assertm(this->values.size() == other->values.size() && assertm(this->values.size() == other->values.size() &&
this->values[0].size() == other->values[0].size(), this->values[0].size() == other->values[0].size(),
@ -129,12 +133,14 @@ void Matrix::Print(std::string_view titre){
} }
// Add a constant value to every matrix case // Add a constant value to every matrix case
void Matrix::Add(float value){ Matrix Matrix::Add(float value){
for(int m = 0; m < this->values.size(); m++){ Matrix result = this;
for(int n = 0; n < this->values[m].size(); n++){ for(int m = 0; m < result.values.size(); m++){
this->values[m][n] += value; for(int n = 0; n < result.values[m].size(); n++){
result.values[m][n] += value;
} }
} }
return result;
} }
// Substract a constant value to every matrix case // Substract a constant value to every matrix case