diff --git a/Base/include/MatrixBase.h b/Base/include/MatrixBase.h index ba431c1..5bd4c1e 100644 --- a/Base/include/MatrixBase.h +++ b/Base/include/MatrixBase.h @@ -30,11 +30,17 @@ namespace FasTC { template - class MatrixBase : public VectorBase { - private: - typedef VectorBase Base; + class MatrixBase { + protected: + + // Vector representation + T mat[nRows * nCols]; + public: - static const int Size = Base::Size; + typedef T ScalarType; + static const int kNumRows = nRows; + static const int kNumCols = nCols; + static const int Size = kNumCols * kNumRows; // Constructors MatrixBase() { } @@ -45,16 +51,16 @@ namespace FasTC { } // Accessors - T &operator()(int idx) { return Base::operator()(idx); } - T &operator[](int idx) { return Base::operator[](idx); } - const T &operator()(int idx) const { return Base::operator()(idx); } - const T &operator[](int idx) const { return Base::operator[](idx); } + T &operator()(int idx) { return mat[idx]; } + T &operator[](int idx) { return mat[idx]; } + const T &operator()(int idx) const { return mat[idx]; } + const T &operator[](int idx) const { return mat[idx]; } T &operator()(int r, int c) { return (*this)[r * nCols + c]; } const T &operator() (int r, int c) const { return (*this)[r * nCols + c]; } // Allow casts to the respective array representation... - operator const T *() const { return this->vec; } + operator const T *() const { return this->mat; } MatrixBase &operator=(const T *v) { for(int i = 0; i < Size; i++) (*this)[i] = v[i]; @@ -66,7 +72,7 @@ namespace FasTC { operator MatrixBase<_T, nRows, nCols>() const { MatrixBase<_T, nRows, nCols> ret; for(int i = 0; i < Size; i++) { - ret[i] = static_cast<_T>(this->vec[i]); + ret[i] = static_cast<_T>(mat[i]); } return ret; } @@ -87,8 +93,20 @@ namespace FasTC { // Vector multiplication -- treat vectors as Nx1 matrices... template - VectorBase operator*(const VectorBase<_T, nCols> &v) { + VectorBase MultiplyVectorLeft(const VectorBase<_T, nRows> &v) const { VectorBase result; + for(int j = 0; j < nCols; j++) { + result(j) = 0; + for(int r = 0; r < nRows; r++) { + result(j) += (*this)(r, j) * v(r); + } + } + return result; + } + + template + VectorBase MultiplyVectorRight(const VectorBase<_T, nCols> &v) const { + VectorBase result; for(int r = 0; r < nRows; r++) { result(r) = 0; for(int j = 0; j < nCols; j++) { @@ -111,14 +129,88 @@ namespace FasTC { // Double dot product template - T DDot(const MatrixBase<_T, nRows, nCols> &m) { + T DDot(const MatrixBase<_T, nRows, nCols> &m) const { T result = 0; for(int i = 0; i < Size; i++) { result += (*this)[i] * m[i]; } return result; } + }; + template + class VectorTraits > { + public: + static const EVectorType kVectorType = eVectorType_Matrix; + }; + + #define REGISTER_MATRIX_TYPE(TYPE) \ + template<> \ + class VectorTraits< TYPE > { \ + public: \ + static const EVectorType kVectorType = eVectorType_Matrix; \ + } + + #define REGISTER_ONE_TEMPLATE_MATRIX_TYPE(TYPE) \ + template \ + class VectorTraits< TYPE > { \ + public: \ + static const EVectorType kVectorType = eVectorType_Matrix; \ + } + + // Define matrix multiplication for * operator + template + class MultSwitch< + eVectorType_Matrix, + eVectorType_Vector, + TypeOne, TypeTwo> { + private: + const TypeOne &m_A; + const TypeTwo &m_B; + + public: + typedef VectorBase ResultType; + + MultSwitch(const TypeOne &a, const TypeTwo &b) + : m_A(a), m_B(b) { } + + ResultType GetMultiplication() const { return m_A.MultiplyVectorRight(m_B); } + }; + + template + class MultSwitch< + eVectorType_Vector, + eVectorType_Matrix, + TypeOne, TypeTwo> { + private: + const TypeOne &m_A; + const TypeTwo &m_B; + + public: + typedef VectorBase ResultType; + + MultSwitch(const TypeOne &a, const TypeTwo &b) + : m_A(a), m_B(b) { } + + ResultType GetMultiplication() const { return m_B.MultiplyVectorLeft(m_A); } + }; + + template + class MultSwitch< + eVectorType_Matrix, + eVectorType_Matrix, + TypeOne, TypeTwo> { + private: + const TypeOne &m_A; + const TypeTwo &m_B; + + public: + typedef MatrixBase ResultType; + + MultSwitch(const TypeOne &a, const TypeTwo &b) + : m_A(a), m_B(b) { } + + ResultType GetMultiplication() const { return m_A.MultiplyMatrix(m_B); } }; // Outer product... diff --git a/Base/test/TestMatrix.cpp b/Base/test/TestMatrix.cpp index 08ba2bb..2e26cf5 100644 --- a/Base/test/TestMatrix.cpp +++ b/Base/test/TestMatrix.cpp @@ -158,9 +158,9 @@ TEST(MatrixBase, MatrixMultiplication) { TEST(MatrixBase, Transposition) { FasTC::MatrixBase a; - a(0, 0) = -1; a(0, 1) = 2; a(0, 2) = -4; a(0, 3) = 5; a(0, 4) = 0; - a(1, 0) = 1; a(1, 1) = 2; a(1, 2) = 4; a(1, 3) = 6; a(1, 4) = 3; - a(2, 0) = -1; a(2, 1) = -2; a(2, 2) = -3; a(2, 3) = -4; a(2, 4) = 5; + a(0, 0) = -1; a(0, 1) = 2; a(0, 2) = -4; a(0, 3) = 5; a(0, 4) = 0; + a(1, 0) = 1; a(1, 1) = 2; a(1, 2) = 4; a(1, 3) = 6; a(1, 4) = 3; + a(2, 0) = -1; a(2, 1) = -2; a(2, 2) = -3; a(2, 3) = -4; a(2, 4) = 5; FasTC::MatrixBase b = a.Transpose(); @@ -172,8 +172,30 @@ TEST(MatrixBase, Transposition) { } TEST(MatrixBase, VectorMultiplication) { - // Stub - EXPECT_EQ(0, 1); + + FasTC::MatrixBase a; + a(0, 0) = -1; a(0, 1) = 2; a(0, 2) = -4; a(0, 3) = 5; a(0, 4) = 0; + a(1, 0) = 1; a(1, 1) = 2; a(1, 2) = 4; a(1, 3) = 6; a(1, 4) = 3; + a(2, 0) = -1; a(2, 1) = -2; a(2, 2) = -3; a(2, 3) = -4; a(2, 4) = 5; + + FasTC::VectorBase v; + for(int i = 0; i < 5; i++) v[i] = i + 1; + + FasTC::VectorBase u = a * v; + EXPECT_EQ(u[0], -1 + (2 * 2) - (4 * 3) + (5 * 4)); + EXPECT_EQ(u[1], 1 + (2 * 2) + (4 * 3) + (6 * 4) + (3 * 5)); + EXPECT_EQ(u[2], -1 + (-2 * 2) - (3 * 3) - (4 * 4) + (5 * 5)); + + ///// + + for(int i = 0; i < 3; i++) u[i] = i + 1; + v = u * a; + + EXPECT_EQ(v[0], -1 + (1 * 2) - (1 * 3)); + EXPECT_EQ(v[1], 2 + (2 * 2) - (2 * 3)); + EXPECT_EQ(v[2], -4 + (4 * 2) - (3 * 3)); + EXPECT_EQ(v[3], 5 + (6 * 2) - (4 * 3)); + EXPECT_EQ(v[4], 0 + (3 * 2) + (5 * 3)); } TEST(MatrixSquare, Constructors) {