xenia/third_party/crunch/crnlib/crn_matrix.h

566 lines
15 KiB
C++

// File: crn_matrix.h
// See Copyright Notice and license at the end of inc/crnlib.h
#pragma once
#include "crn_vec.h"
namespace crnlib
{
template<class X, class Y, class Z> Z& matrix_mul_helper(Z& result, const X& lhs, const Y& rhs)
{
CRNLIB_ASSUME(Z::num_rows == X::num_rows);
CRNLIB_ASSUME(Z::num_cols == Y::num_cols);
CRNLIB_ASSUME(X::num_cols == Y::num_rows);
CRNLIB_ASSERT((&result != &lhs) && (&result != &rhs));
for (int r = 0; r < X::num_rows; r++)
for (int c = 0; c < Y::num_cols; c++)
{
typename Z::scalar_type s = lhs(r, 0) * rhs(0, c);
for (uint i = 1; i < X::num_cols; i++)
s += lhs(r, i) * rhs(i, c);
result(r, c) = s;
}
return result;
}
template<class X, class Y, class Z> Z& matrix_mul_helper_transpose_lhs(Z& result, const X& lhs, const Y& rhs)
{
CRNLIB_ASSUME(Z::num_rows == X::num_cols);
CRNLIB_ASSUME(Z::num_cols == Y::num_cols);
CRNLIB_ASSUME(X::num_rows == Y::num_rows);
for (int r = 0; r < X::num_cols; r++)
for (int c = 0; c < Y::num_cols; c++)
{
typename Z::scalar_type s = lhs(0, r) * rhs(0, c);
for (uint i = 1; i < X::num_rows; i++)
s += lhs(i, r) * rhs(i, c);
result(r, c) = s;
}
return result;
}
template<class X, class Y, class Z> Z& matrix_mul_helper_transpose_rhs(Z& result, const X& lhs, const Y& rhs)
{
CRNLIB_ASSUME(Z::num_rows == X::num_rows);
CRNLIB_ASSUME(Z::num_cols == Y::num_rows);
CRNLIB_ASSUME(X::num_cols == Y::num_cols);
for (int r = 0; r < X::num_rows; r++)
for (int c = 0; c < Y::num_rows; c++)
{
typename Z::scalar_type s = lhs(r, 0) * rhs(c, 0);
for (uint i = 1; i < X::num_cols; i++)
s += lhs(r, i) * rhs(c, i);
result(r, c) = s;
}
return result;
}
template<uint R, uint C, typename T>
class matrix
{
public:
typedef T scalar_type;
enum { num_rows = R, num_cols = C };
typedef vec<R, T> col_vec;
typedef vec<(R > 1) ? (R - 1) : 0, T> subcol_vec;
typedef vec<C, T> row_vec;
typedef vec<(C > 1) ? (C - 1) : 0, T> subrow_vec;
inline matrix() { }
inline matrix(eClear) { clear(); }
inline matrix(const T* p) { set(p); }
inline matrix(const matrix& other)
{
for (uint i = 0; i < R; i++)
m_rows[i] = other.m_rows[i];
}
inline matrix& operator= (const matrix& rhs)
{
if (this != &rhs)
for (uint i = 0; i < R; i++)
m_rows[i] = rhs.m_rows[i];
return *this;
}
inline matrix(T val00, T val01,
T val10, T val11)
{
set(val00, val01, val10, val11);
}
inline matrix(T val00, T val01, T val02,
T val10, T val11, T val12,
T val20, T val21, T val22)
{
set(val00, val01, val02, val10, val11, val12, val20, val21, val22);
}
inline matrix(T val00, T val01, T val02, T val03,
T val10, T val11, T val12, T val13,
T val20, T val21, T val22, T val23,
T val30, T val31, T val32, T val33)
{
set(val00, val01, val02, val03, val10, val11, val12, val13, val20, val21, val22, val23, val30, val31, val32, val33);
}
inline void set(const float* p)
{
for (uint i = 0; i < R; i++)
{
m_rows[i].set(p);
p += C;
}
}
inline void set(T val00, T val01,
T val10, T val11)
{
m_rows[0].set(val00, val01);
if (R >= 2)
{
m_rows[1].set(val10, val11);
for (uint i = 2; i < R; i++)
m_rows[i].clear();
}
}
inline void set(T val00, T val01, T val02,
T val10, T val11, T val12,
T val20, T val21, T val22)
{
m_rows[0].set(val00, val01, val02);
if (R >= 2)
{
m_rows[1].set(val10, val11, val12);
if (R >= 3)
{
m_rows[2].set(val20, val21, val22);
for (uint i = 3; i < R; i++)
m_rows[i].clear();
}
}
}
inline void set(T val00, T val01, T val02, T val03,
T val10, T val11, T val12, T val13,
T val20, T val21, T val22, T val23,
T val30, T val31, T val32, T val33)
{
m_rows[0].set(val00, val01, val02, val03);
if (R >= 2)
{
m_rows[1].set(val10, val11, val12, val13);
if (R >= 3)
{
m_rows[2].set(val20, val21, val22, val23);
if (R >= 4)
{
m_rows[3].set(val30, val31, val32, val33);
for (uint i = 4; i < R; i++)
m_rows[i].clear();
}
}
}
}
inline T operator() (uint r, uint c) const
{
CRNLIB_ASSERT((r < R) && (c < C));
return m_rows[r][c];
}
inline T& operator() (uint r, uint c)
{
CRNLIB_ASSERT((r < R) && (c < C));
return m_rows[r][c];
}
inline const row_vec& operator[] (uint r) const
{
CRNLIB_ASSERT(r < R);
return m_rows[r];
}
inline row_vec& operator[] (uint r)
{
CRNLIB_ASSERT(r < R);
return m_rows[r];
}
inline const row_vec& get_row (uint r) const { return (*this)[r]; }
inline row_vec& get_row (uint r) { return (*this)[r]; }
inline col_vec get_col(uint c) const
{
CRNLIB_ASSERT(c < C);
col_vec result;
for (uint i = 0; i < R; i++)
result[i] = m_rows[i][c];
return result;
}
inline void set_col(uint c, const col_vec& col)
{
CRNLIB_ASSERT(c < C);
for (uint i = 0; i < R; i++)
m_rows[i][c] = col[i];
}
inline void set_col(uint c, const subcol_vec& col)
{
CRNLIB_ASSERT(c < C);
for (uint i = 0; i < (R - 1); i++)
m_rows[i][c] = col[i];
m_rows[R - 1][c] = 0.0f;
}
inline const row_vec& get_translate() const
{
return m_rows[R - 1];
}
inline matrix& set_translate(const row_vec& r)
{
m_rows[R - 1] = r;
return *this;
}
inline matrix& set_translate(const subrow_vec& r)
{
m_rows[R - 1] = row_vec(r).as_point();
return *this;
}
inline const T* get_ptr() const { return reinterpret_cast<const T*>(&m_rows[0]); }
inline T* get_ptr() { return reinterpret_cast< T*>(&m_rows[0]); }
inline matrix& operator+= (const matrix& other)
{
for (uint i = 0; i < R; i++)
m_rows[i] += other.m_rows[i];
return *this;
}
inline matrix& operator-= (const matrix& other)
{
for (uint i = 0; i < R; i++)
m_rows[i] -= other.m_rows[i];
return *this;
}
inline matrix& operator*= (T val)
{
for (uint i = 0; i < R; i++)
m_rows[i] *= val;
return *this;
}
inline matrix& operator/= (T val)
{
for (uint i = 0; i < R; i++)
m_rows[i] /= val;
return *this;
}
inline matrix& operator*= (const matrix& other)
{
matrix result;
matrix_mul_helper(result, *this, other);
*this = result;
return *this;
}
friend inline matrix operator+ (const matrix& lhs, const matrix& rhs)
{
matrix result;
for (uint i = 0; i < R; i++)
result[i] = lhs.m_rows[i] + rhs.m_rows[i];
return result;
}
friend inline matrix operator- (const matrix& lhs, const matrix& rhs)
{
matrix result;
for (uint i = 0; i < R; i++)
result[i] = lhs.m_rows[i] - rhs.m_rows[i];
return result;
}
friend inline matrix operator* (const matrix& lhs, T val)
{
matrix result;
for (uint i = 0; i < R; i++)
result[i] = lhs.m_rows[i] * val;
return result;
}
friend inline matrix operator/ (const matrix& lhs, T val)
{
matrix result;
for (uint i = 0; i < R; i++)
result[i] = lhs.m_rows[i] / val;
return result;
}
friend inline matrix operator* (T val, const matrix& rhs)
{
matrix result;
for (uint i = 0; i < R; i++)
result[i] = val * rhs.m_rows[i];
return result;
}
friend inline matrix operator* (const matrix& lhs, const matrix& rhs)
{
matrix result;
return matrix_mul_helper(result, lhs, rhs);
}
friend inline row_vec operator* (const col_vec& a, const matrix& b)
{
return transform(a, b);
}
inline matrix operator+ () const
{
return *this;
}
inline matrix operator- () const
{
matrix result;
for (uint i = 0; i < R; i++)
result[i] = -m_rows[i];
return result;
}
inline void clear(void)
{
for (uint i = 0; i < R; i++)
m_rows[i].clear();
}
inline void set_zero_matrix()
{
clear();
}
inline void set_identity_matrix()
{
for (uint i = 0; i < R; i++)
{
m_rows[i].clear();
m_rows[i][i] = 1.0f;
}
}
inline matrix& set_scale_matrix(float s)
{
clear();
for (int i = 0; i < (R - 1); i++)
m_rows[i][i] = s;
m_rows[R - 1][C - 1] = 1.0f;
return *this;
}
inline matrix& set_scale_matrix(const row_vec& s)
{
clear();
for (uint i = 0; i < R; i++)
m_rows[i][i] = s[i];
return *this;
}
inline matrix& set_translate_matrix(const row_vec& s)
{
set_identity_matrix();
set_translate(s);
return *this;
}
inline matrix& set_translate_matrix(float x, float y)
{
set_identity_matrix();
set_translate(row_vec(x, y).as_point());
return *this;
}
inline matrix& set_translate_matrix(float x, float y, float z)
{
set_identity_matrix();
set_translate(row_vec(x, y, z).as_point());
return *this;
}
inline matrix get_transposed(void) const
{
matrix result;
for (uint i = 0; i < R; i++)
for (uint j = 0; j < C; j++)
result.m_rows[i][j] = m_rows[j][i];
return result;
}
inline matrix& transpose_in_place(void)
{
matrix result;
for (uint i = 0; i < R; i++)
for (uint j = 0; j < C; j++)
result.m_rows[i][j] = m_rows[j][i];
*this = result;
return *this;
}
// This method transforms a column vec by a matrix (D3D-style).
static inline row_vec transform(const col_vec& a, const matrix& b)
{
row_vec result(b[0] * a[0]);
for (uint r = 1; r < R; r++)
result += b[r] * a[r];
return result;
}
// This method transforms a column vec by a matrix. Last component of vec is assumed to be 1.
static inline row_vec transform_point(const col_vec& a, const matrix& b)
{
row_vec result(0);
for (int r = 0; r < (R - 1); r++)
result += b[r] * a[r];
result += b[R - 1];
return result;
}
// This method transforms a column vec by a matrix. Last component of vec is assumed to be 0.
static inline row_vec transform_vector(const col_vec& a, const matrix& b)
{
row_vec result(0);
for (int r = 0; r < (R - 1); r++)
result += b[r] * a[r];
return result;
}
static inline subcol_vec transform_point(const subcol_vec& a, const matrix& b)
{
subcol_vec result(0);
for (int r = 0; r < R; r++)
{
const T s = (r < subcol_vec::num_elements) ? a[r] : 1.0f;
for (int c = 0; c < (C - 1); c++)
result[c] += b[r][c] * s;
}
return result;
}
static inline subcol_vec transform_vector(const subcol_vec& a, const matrix& b)
{
subcol_vec result(0);
for (int r = 0; r < (R - 1); r++)
{
const T s = a[r];
for (int c = 0; c < (C - 1); c++)
result[c] += b[r][c] * s;
}
return result;
}
// This method transforms a column vec by the transpose of a matrix.
static inline col_vec transform_transposed(const matrix& b, const col_vec& a)
{
CRNLIB_ASSUME(R == C);
col_vec result;
for (uint r = 0; r < R; r++)
result[r] = b[r] * a;
return result;
}
// This method transforms a column vec by the transpose of a matrix. Last component of vec is assumed to be 0.
static inline col_vec transform_vector_transposed(const matrix& b, const col_vec& a)
{
CRNLIB_ASSUME(R == C);
col_vec result;
for (uint r = 0; r < R; r++)
{
T s = 0;
for (uint c = 0; c < (C - 1); c++)
s += b[r][c] * a[c];
result[r] = s;
}
return result;
}
// This method transforms a matrix by a row vector (OGL style).
static inline col_vec transform(const matrix& b, const row_vec& a)
{
col_vec result;
for (int r = 0; r < R; r++)
result[r] = b[r] * a;
return result;
}
static inline matrix& multiply(matrix& result, const matrix& lhs, const matrix& rhs)
{
return matrix_mul_helper(result, lhs, rhs);
}
static inline matrix make_scale_matrix(float s)
{
return matrix().set_scale_matrix(s);
}
static inline matrix make_scale_matrix(const row_vec& s)
{
return matrix().set_scale_matrix(s);
}
static inline matrix make_scale_matrix(float x, float y)
{
CRNLIB_ASSUME(R >= 3 && C >= 3);
matrix result;
result.clear();
result.m_rows[0][0] = x;
result.m_rows[1][1] = y;
result.m_rows[2][2] = 1.0f;
return result;
}
static inline matrix make_scale_matrix(float x, float y, float z)
{
CRNLIB_ASSUME(R >= 4 && C >= 4);
matrix result;
result.clear();
result.m_rows[0][0] = x;
result.m_rows[1][1] = y;
result.m_rows[2][2] = z;
result.m_rows[3][3] = 1.0f;
return result;
}
private:
row_vec m_rows[R];
};
typedef matrix<2, 2, float> matrix22F;
typedef matrix<2, 2, double> matrix22D;
typedef matrix<3, 3, float> matrix33F;
typedef matrix<3, 3, double> matrix33D;
typedef matrix<4, 4, float> matrix44F;
typedef matrix<4, 4, double> matrix44D;
typedef matrix<8, 8, float> matrix88F;
} // namespace crnlib