VCCC  2024.05
VisualCamp Common C++ library
matrix_mul_matrix.hpp
Go to the documentation of this file.
1 # /*
2 # * Created by YongGyu Lee on 2020/02/06.
3 # */
4 #
5 # ifndef VCCC_MATH_MATRIX_MATRIX_MUL_MATRIX_HPP
6 # define VCCC_MATH_MATRIX_MATRIX_MUL_MATRIX_HPP
7 #
10 
11 namespace vccc {
12 
13 namespace internal {
14 namespace math {
15 
16 template<typename LhsType, typename RhsType>
17 struct traits<MatrixMulMatrix<LhsType, RhsType>> {
18  enum : int {
19  rows = traits<LhsType>::rows,
20  cols = traits<RhsType>::cols,
21  size = rows * cols,
22  };
23 
24  enum : int {
25  option = traits<LhsType>::option | traits<RhsType>::option | Flag::kAliasUnsafe | Flag::kReferenceUnsafe
26  };
27  using value_type = typename LhsType::value_type;
28 };
29 
30 } // namespace math
31 } // namespace internal
32 
35 
36 template<typename LhsType, typename RhsType>
37 class MatrixMulMatrix : public MatrixBase<MatrixMulMatrix<LhsType, RhsType>>{
38  public:
40  using value_type = typename internal::math::traits<MatrixMulMatrix>::value_type;
41  using lhs_type = internal::math::hold_type_selector_t<LhsType>;
42  using rhs_type = internal::math::hold_type_selector_t<RhsType>;
43 
44  constexpr MatrixMulMatrix(const LhsType& lhs, const RhsType& rhs) : lhs(lhs), rhs(rhs) {};
45 
46  constexpr auto operator() (std::size_t i) const;
47  constexpr auto operator() (std::size_t i, std::size_t j) const;
48  constexpr auto operator[] (std::size_t i) const;
49 
50  private:
51  lhs_type lhs;
52  rhs_type rhs;
53 
54  using tag = internal::math::tag<int, LhsType::rows, LhsType::cols, RhsType::cols>;
55 
56 // template<int p, int q, int r>
57 // constexpr
58 // void mul(return_type& dst, tag<p, q, r>) const;
59 //
60 // constexpr
61 // void mul(return_type& dst, tag<1, 1, 1>) const;
62 //
63 // template<int p>
64 // constexpr
65 // void mul(return_type& dst, tag<1, p, 1>) const;
66 
67  // TODO: optimize 2n^2n matrix multiplication
68 // constexpr inline
69 // void mul(return_type& dst, tag<2, 2, 2>) const;
70 
71 };
72 
73 template<typename E1, typename E2,
74  std::enable_if_t<(internal::math::traits<E1>::cols == internal::math::traits<E2>::rows), int> = 0>
75 constexpr inline
77 operator*(const MatrixBase<E1>& lhs, const MatrixBase<E2>& rhs) {
78  return MatrixMulMatrix<E1, E2>(*static_cast<const E1*>(&lhs), *static_cast<const E2*>(&rhs));
79 }
80 
82 template<typename LhsType, typename RhsType>
83 constexpr auto MatrixMulMatrix<LhsType, RhsType>::operator()(std::size_t i, std::size_t j) const {
84  value_type sum(0);
85  for(int k=0; k<tag::second; ++k)
86  sum += lhs(i, k) * rhs(k, j);
87  return sum;
88 }
89 
90 template<typename LhsType, typename RhsType>
91 constexpr auto MatrixMulMatrix<LhsType, RhsType>::operator()(std::size_t i) const {
92  return operator()(i / base::cols, i % base::cols);
93 }
94 
95 template<typename LhsType, typename RhsType>
96 constexpr auto MatrixMulMatrix<LhsType, RhsType>::operator[](std::size_t i) const {
97  return operator()(i / base::cols, i % base::cols);
98 }
100 
102 
103 
104 //template<typename E1, typename E2>
105 //template<int p, int q, int r>
106 //constexpr void MatrixMulMatrix<E1, E2>::mul(return_type& dst) const {
107 // return mul(dst, tag<p, q, r>());
108 //}
109 //
110 //template<typename E1, typename E2>
111 //template<int p, int q, int r>
112 //constexpr inline void
113 //MatrixMulMatrix<E1, E2>::mul(return_type& dst, MatrixMulMatrix<E1, E2>::tag<p, q, r>) const {
114 // using T = typename MatrixMulMatrix<E1, E2>::value_type;
115 //
116 // for( int i = 0; i < m; i++ ) {
117 // for (int j = 0; j < n; j++) {
118 // T s = 0;
119 // for (int k = 0; k < l; k++)
120 // s += lhs(i, k) * rhs(k, j);
121 // dst.data[i * n + j] = s;
122 // }
123 // }
124 //}
125 //
126 //template<typename E1, typename E2, int m, int l, int n>
127 //constexpr void
128 //MatrixMulMatrix<E1, E2, m, l, n>::mul(return_type& dst, MatrixMulMatrix::tag<1, 1, 1>) const {
129 // dst.data[0] = lhs[0] * rhs[0];
130 //}
131 //
132 //template<typename LhsType, typename RhsType>>
133 //template<int p>
134 //constexpr void
135 //MatrixMulMatrix<LhsType, RhsType>::mul(MatrixMulMatrix::return_type& dst, MatrixMulMatrix::tag<1, p, 1>) const {
136 // using T = MatrixMulMatrix<E1, E2>::value_type;
137 //
138 // T sum = 0;
139 // for(int i=0; i<p; ++i)
140 // sum += lhs[i] * rhs[i];
141 // dst.data[0] = sum;
142 //}
143 
144 } // namespace vccc
145 
146 # endif // VCCC_MATH_MATRIX_MATRIX_MUL_MATRIX_HPP
Definition: matrix_base.hpp:20
Definition: matrix_mul_matrix.hpp:37
internal::math::hold_type_selector_t< RhsType > rhs_type
Definition: matrix_mul_matrix.hpp:42
internal::math::hold_type_selector_t< LhsType > lhs_type
Definition: matrix_mul_matrix.hpp:41
constexpr auto operator[](std::size_t i) const
typename internal::math::traits< MatrixMulMatrix >::value_type value_type
Definition: matrix_mul_matrix.hpp:40
constexpr auto operator()(std::size_t i) const
constexpr MatrixMulMatrix(const LhsType &lhs, const RhsType &rhs)
Definition: matrix_mul_matrix.hpp:44
constexpr auto sum(InputIterator first, InputIterator last)
sum of iterator [first, last)
Definition: sum.hpp:57
Definition: directory.h:12
constexpr VCCC_INLINE_OR_STATIC detail::element_niebloid< 1 > second
Definition: key_value.hpp:36
constexpr auto size(const C &c) -> decltype(c.size())
Definition: size.hpp:16