mirror of
https://github.com/yuzu-emu/FasTC
synced 2024-11-26 01:07:59 +00:00
Add matrix multiplication infrastructure
This commit is contained in:
parent
05eeb09f36
commit
8b9e8cd9b5
2 changed files with 131 additions and 17 deletions
|
@ -30,11 +30,17 @@
|
|||
namespace FasTC {
|
||||
|
||||
template <typename T, const int nRows, const int nCols>
|
||||
class MatrixBase : public VectorBase<T, nRows * nCols> {
|
||||
private:
|
||||
typedef VectorBase<T, nRows * nCols> Base;
|
||||
class MatrixBase {
|
||||
protected:
|
||||
|
||||
// Vector representation
|
||||
T mat[nRows * nCols];
|
||||
|
||||
public:
|
||||
static const int Size = Base::Size;
|
||||
typedef T ScalarType;
|
||||
static const int kNumRows = nRows;
|
||||
static const int kNumCols = nCols;
|
||||
static const int Size = kNumCols * kNumRows;
|
||||
|
||||
// Constructors
|
||||
MatrixBase() { }
|
||||
|
@ -45,16 +51,16 @@ namespace FasTC {
|
|||
}
|
||||
|
||||
// Accessors
|
||||
T &operator()(int idx) { return Base::operator()(idx); }
|
||||
T &operator[](int idx) { return Base::operator[](idx); }
|
||||
const T &operator()(int idx) const { return Base::operator()(idx); }
|
||||
const T &operator[](int idx) const { return Base::operator[](idx); }
|
||||
T &operator()(int idx) { return mat[idx]; }
|
||||
T &operator[](int idx) { return mat[idx]; }
|
||||
const T &operator()(int idx) const { return mat[idx]; }
|
||||
const T &operator[](int idx) const { return mat[idx]; }
|
||||
|
||||
T &operator()(int r, int c) { return (*this)[r * nCols + c]; }
|
||||
const T &operator() (int r, int c) const { return (*this)[r * nCols + c]; }
|
||||
|
||||
// Allow casts to the respective array representation...
|
||||
operator const T *() const { return this->vec; }
|
||||
operator const T *() const { return this->mat; }
|
||||
MatrixBase<T, nRows, nCols> &operator=(const T *v) {
|
||||
for(int i = 0; i < Size; i++)
|
||||
(*this)[i] = v[i];
|
||||
|
@ -66,7 +72,7 @@ namespace FasTC {
|
|||
operator MatrixBase<_T, nRows, nCols>() const {
|
||||
MatrixBase<_T, nRows, nCols> ret;
|
||||
for(int i = 0; i < Size; i++) {
|
||||
ret[i] = static_cast<_T>(this->vec[i]);
|
||||
ret[i] = static_cast<_T>(mat[i]);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
@ -87,8 +93,20 @@ namespace FasTC {
|
|||
|
||||
// Vector multiplication -- treat vectors as Nx1 matrices...
|
||||
template<typename _T>
|
||||
VectorBase<T, nCols> operator*(const VectorBase<_T, nCols> &v) {
|
||||
VectorBase<T, nCols> MultiplyVectorLeft(const VectorBase<_T, nRows> &v) const {
|
||||
VectorBase<T, nCols> result;
|
||||
for(int j = 0; j < nCols; j++) {
|
||||
result(j) = 0;
|
||||
for(int r = 0; r < nRows; r++) {
|
||||
result(j) += (*this)(r, j) * v(r);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
template<typename _T>
|
||||
VectorBase<T, nRows> MultiplyVectorRight(const VectorBase<_T, nCols> &v) const {
|
||||
VectorBase<T, nRows> result;
|
||||
for(int r = 0; r < nRows; r++) {
|
||||
result(r) = 0;
|
||||
for(int j = 0; j < nCols; j++) {
|
||||
|
@ -111,14 +129,88 @@ namespace FasTC {
|
|||
|
||||
// Double dot product
|
||||
template<typename _T>
|
||||
T DDot(const MatrixBase<_T, nRows, nCols> &m) {
|
||||
T DDot(const MatrixBase<_T, nRows, nCols> &m) const {
|
||||
T result = 0;
|
||||
for(int i = 0; i < Size; i++) {
|
||||
result += (*this)[i] * m[i];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
};
|
||||
|
||||
template<typename T, const int N, const int M>
|
||||
class VectorTraits<MatrixBase<T, N, M> > {
|
||||
public:
|
||||
static const EVectorType kVectorType = eVectorType_Matrix;
|
||||
};
|
||||
|
||||
#define REGISTER_MATRIX_TYPE(TYPE) \
|
||||
template<> \
|
||||
class VectorTraits< TYPE > { \
|
||||
public: \
|
||||
static const EVectorType kVectorType = eVectorType_Matrix; \
|
||||
}
|
||||
|
||||
#define REGISTER_ONE_TEMPLATE_MATRIX_TYPE(TYPE) \
|
||||
template<typename T> \
|
||||
class VectorTraits< TYPE <T> > { \
|
||||
public: \
|
||||
static const EVectorType kVectorType = eVectorType_Matrix; \
|
||||
}
|
||||
|
||||
// Define matrix multiplication for * operator
|
||||
template<typename TypeOne, typename TypeTwo>
|
||||
class MultSwitch<
|
||||
eVectorType_Matrix,
|
||||
eVectorType_Vector,
|
||||
TypeOne, TypeTwo> {
|
||||
private:
|
||||
const TypeOne &m_A;
|
||||
const TypeTwo &m_B;
|
||||
|
||||
public:
|
||||
typedef VectorBase<typename TypeTwo::ScalarType, TypeOne::kNumRows> ResultType;
|
||||
|
||||
MultSwitch(const TypeOne &a, const TypeTwo &b)
|
||||
: m_A(a), m_B(b) { }
|
||||
|
||||
ResultType GetMultiplication() const { return m_A.MultiplyVectorRight(m_B); }
|
||||
};
|
||||
|
||||
template<typename TypeOne, typename TypeTwo>
|
||||
class MultSwitch<
|
||||
eVectorType_Vector,
|
||||
eVectorType_Matrix,
|
||||
TypeOne, TypeTwo> {
|
||||
private:
|
||||
const TypeOne &m_A;
|
||||
const TypeTwo &m_B;
|
||||
|
||||
public:
|
||||
typedef VectorBase<typename TypeOne::ScalarType, TypeTwo::kNumCols> ResultType;
|
||||
|
||||
MultSwitch(const TypeOne &a, const TypeTwo &b)
|
||||
: m_A(a), m_B(b) { }
|
||||
|
||||
ResultType GetMultiplication() const { return m_B.MultiplyVectorLeft(m_A); }
|
||||
};
|
||||
|
||||
template<typename TypeOne, typename TypeTwo>
|
||||
class MultSwitch<
|
||||
eVectorType_Matrix,
|
||||
eVectorType_Matrix,
|
||||
TypeOne, TypeTwo> {
|
||||
private:
|
||||
const TypeOne &m_A;
|
||||
const TypeTwo &m_B;
|
||||
|
||||
public:
|
||||
typedef MatrixBase<typename TypeOne::ScalarType, TypeOne::kNumRows, TypeTwo::kNumCols> ResultType;
|
||||
|
||||
MultSwitch(const TypeOne &a, const TypeTwo &b)
|
||||
: m_A(a), m_B(b) { }
|
||||
|
||||
ResultType GetMultiplication() const { return m_A.MultiplyMatrix(m_B); }
|
||||
};
|
||||
|
||||
// Outer product...
|
||||
|
|
|
@ -172,8 +172,30 @@ TEST(MatrixBase, Transposition) {
|
|||
}
|
||||
|
||||
TEST(MatrixBase, VectorMultiplication) {
|
||||
// Stub
|
||||
EXPECT_EQ(0, 1);
|
||||
|
||||
FasTC::MatrixBase<int, 3, 5> a;
|
||||
a(0, 0) = -1; a(0, 1) = 2; a(0, 2) = -4; a(0, 3) = 5; a(0, 4) = 0;
|
||||
a(1, 0) = 1; a(1, 1) = 2; a(1, 2) = 4; a(1, 3) = 6; a(1, 4) = 3;
|
||||
a(2, 0) = -1; a(2, 1) = -2; a(2, 2) = -3; a(2, 3) = -4; a(2, 4) = 5;
|
||||
|
||||
FasTC::VectorBase<int, 5> v;
|
||||
for(int i = 0; i < 5; i++) v[i] = i + 1;
|
||||
|
||||
FasTC::VectorBase<int, 3> u = a * v;
|
||||
EXPECT_EQ(u[0], -1 + (2 * 2) - (4 * 3) + (5 * 4));
|
||||
EXPECT_EQ(u[1], 1 + (2 * 2) + (4 * 3) + (6 * 4) + (3 * 5));
|
||||
EXPECT_EQ(u[2], -1 + (-2 * 2) - (3 * 3) - (4 * 4) + (5 * 5));
|
||||
|
||||
/////
|
||||
|
||||
for(int i = 0; i < 3; i++) u[i] = i + 1;
|
||||
v = u * a;
|
||||
|
||||
EXPECT_EQ(v[0], -1 + (1 * 2) - (1 * 3));
|
||||
EXPECT_EQ(v[1], 2 + (2 * 2) - (2 * 3));
|
||||
EXPECT_EQ(v[2], -4 + (4 * 2) - (3 * 3));
|
||||
EXPECT_EQ(v[3], 5 + (6 * 2) - (4 * 3));
|
||||
EXPECT_EQ(v[4], 0 + (3 * 2) + (5 * 3));
|
||||
}
|
||||
|
||||
TEST(MatrixSquare, Constructors) {
|
||||
|
|
Loading…
Reference in a new issue