// Storage-kind promotion for outer products: when one operand is sparse and the
// other is dense (in either order), the outer product's storage kind is Sparse.
10 #ifndef EIGEN_SPARSEDENSEPRODUCT_H 11 #define EIGEN_SPARSEDENSEPRODUCT_H 17 template <>
struct product_promote_storage_type<Sparse,Dense, OuterProduct> {
typedef Sparse ret; };
18 template <>
struct product_promote_storage_type<Dense,Sparse, OuterProduct> {
typedef Sparse ret; };
// Kernel selector for res += alpha * lhs * rhs, with a sparse lhs and a dense
// rhs/res. ColPerCol defaults to true when the rhs is column-major or has exactly
// one column at compile time; the "true" kernels then walk the rhs one column at
// a time. Otherwise the row-wise (ColPerCol == false) kernels are used.
// NOTE(review): the specializations in this file take six template arguments (an
// alpha scalar type and a storage order precede ColPerCol), but those parameters
// are not visible here. The embedded source line numbers jump from 20 to 23, so
// they appear to have been dropped from this listing; confirm against the header.
20 template<
typename SparseLhsType,
typename DenseRhsType,
typename DenseResType,
23 bool ColPerCol = ((DenseRhsType::Flags&
RowMajorBit)==0) || DenseRhsType::ColsAtCompileTime==1>
24 struct sparse_time_dense_product_impl;
// Kernel for a row-major sparse lhs when the rhs is processed one column at a
// time (ColPerCol == true), with alpha of the result's scalar type.
// For every rhs column c and every lhs row i it accumulates
//   res(i,c) += alpha * sum_k lhs(i,k) * rhs(k,c)
// through processRow().
// NOTE(review): this listing has dropped lines (braces are missing and the
// embedded source line numbers skip). In particular `lhsEval` is used in run()
// but its declaration is not shown. It is presumably `LhsEval lhsEval(lhs);`;
// confirm against the original header.
26 template<
typename SparseLhsType,
typename DenseRhsType,
typename DenseResType>
27 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar,
RowMajor, true>
29 typedef typename internal::remove_all<SparseLhsType>::type Lhs;
30 typedef typename internal::remove_all<DenseRhsType>::type Rhs;
31 typedef typename internal::remove_all<DenseResType>::type Res;
32 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
33 typedef evaluator<Lhs> LhsEval;
34 static void run(
const SparseLhsType& lhs,
const DenseRhsType& rhs, DenseResType& res,
const typename Res::Scalar& alpha)
38 Index n = lhs.outerSize();
39 #ifdef EIGEN_HAS_OPENMP 40 Eigen::initParallel();
41 Index threads = Eigen::nbThreads();
44 for(Index c=0; c<rhs.cols(); ++c)
// With OpenMP enabled, rows are split across threads only when more than one
// thread is available and the lhs non-zero estimate exceeds 20000. Each
// iteration i writes only res(i,c), so concurrent iterations touch disjoint
// entries. The serial loop below is presumably the fallback path; the
// connecting else/#endif lines are not visible in this listing.
// NOTE(review): schedule(static) hands each thread an equal count of rows
// regardless of their non-zero counts, which may load-imbalance matrices with
// very irregular row lengths.
46 #ifdef EIGEN_HAS_OPENMP 49 if(threads>1 && lhsEval.nonZerosEstimate() > 20000)
51 #pragma omp parallel for schedule(static) num_threads(threads) 52 for(Index i=0; i<n; ++i)
53 processRow(lhsEval,rhs,res,alpha,i,c);
58 for(Index i=0; i<n; ++i)
59 processRow(lhsEval,rhs,res,alpha,i,c);
// Dot product of sparse row i of the lhs with column `col` of rhs. The sum is
// accumulated in a local, then scaled by alpha once and added to res(i,col).
64 static void processRow(
const LhsEval& lhsEval,
const DenseRhsType& rhs, DenseResType& res,
const typename Res::Scalar& alpha, Index i, Index col)
66 typename Res::Scalar tmp(0);
67 for(LhsInnerIterator it(lhsEval,i); it ;++it)
68 tmp += it.value() * rhs.coeff(it.index(),col);
69 res.coeffRef(i,col) += alpha * tmp;
// scalar_product_traits specialization for a Ref<T2> right operand. The product's
// ReturnType is the plain-object type of T2 scaled coefficient-wise by a T1
// (a CwiseUnaryOp built on scalar_multiple2_op).
// NOTE(review): source lines 77-80 are not visible in this listing, so any other
// members of this specialization cannot be confirmed from here.
75 template<
typename T1,
typename T2>
76 struct scalar_product_traits<T1, Ref<T2> >
81 typedef typename CwiseUnaryOp<scalar_multiple2_op<T1, typename T2::Scalar>, T2>::PlainObject ReturnType;
// Kernel for a column-major sparse lhs, processed one rhs column at a time
// (ColPerCol == true). Unlike the other kernels, alpha has a generic AlphaType.
// For each rhs column c and lhs column j:
//   - alpha * rhs(j,c) is formed once, typed via scalar_product_traits;
//   - it is then scattered as res(i,c) += lhs(i,j) * rhs_j for every stored
//     entry lhs(i,j).
86 template<
typename SparseLhsType,
typename DenseRhsType,
typename DenseResType,
typename AlphaType>
84 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType,
ColMajor, true>
86 typedef typename internal::remove_all<SparseLhsType>::type Lhs;
87 typedef typename internal::remove_all<DenseRhsType>::type Rhs;
88 typedef typename internal::remove_all<DenseResType>::type Res;
89 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
90 static void run(
const SparseLhsType& lhs,
const DenseRhsType& rhs, DenseResType& res,
const AlphaType& alpha)
92 evaluator<Lhs> lhsEval(lhs);
93 for(Index c=0; c<rhs.cols(); ++c)
95 for(Index j=0; j<lhs.outerSize(); ++j)
// The scaled rhs coefficient is hoisted out of the inner loop. The lhs value is
// kept on the left of the product, matching the operand order of lhs * rhs.
98 typename internal::scalar_product_traits<AlphaType, typename Rhs::Scalar>::ReturnType rhs_j(alpha * rhs.coeff(j,c));
99 for(LhsInnerIterator it(lhsEval,j); it ;++it)
100 res.coeffRef(it.index(),c) += it.value() * rhs_j;
// Kernel for a row-major sparse lhs when ColPerCol == false, i.e. the rhs is
// row-major and not known at compile time to have exactly one column.
// For each lhs row j it accumulates whole rows:
//   res.row(j) += sum_k (alpha * lhs(j,k)) * rhs.row(k)
// over the stored entries k of row j.
106 template<
typename SparseLhsType,
typename DenseRhsType,
typename DenseResType>
107 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar,
RowMajor, false>
109 typedef typename internal::remove_all<SparseLhsType>::type Lhs;
110 typedef typename internal::remove_all<DenseRhsType>::type Rhs;
111 typedef typename internal::remove_all<DenseResType>::type Res;
112 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
113 static void run(
const SparseLhsType& lhs,
const DenseRhsType& rhs, DenseResType& res,
const typename Res::Scalar& alpha)
115 evaluator<Lhs> lhsEval(lhs);
116 for(Index j=0; j<lhs.outerSize(); ++j)
118 typename Res::RowXpr res_j(res.row(j));
119 for(LhsInnerIterator it(lhsEval,j); it ;++it)
120 res_j += (alpha*it.value()) * rhs.row(it.index());
// Kernel for a column-major sparse lhs when ColPerCol == false (row-major rhs).
// For each lhs column j, every stored entry lhs(i,j) scatters a scaled copy of
// rhs.row(j) into the result:
//   res.row(i) += (alpha * lhs(i,j)) * rhs.row(j)
125 template<
typename SparseLhsType,
typename DenseRhsType,
typename DenseResType>
126 struct sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, typename DenseResType::Scalar,
ColMajor, false>
128 typedef typename internal::remove_all<SparseLhsType>::type Lhs;
129 typedef typename internal::remove_all<DenseRhsType>::type Rhs;
130 typedef typename internal::remove_all<DenseResType>::type Res;
131 typedef typename evaluator<Lhs>::InnerIterator LhsInnerIterator;
132 static void run(
const SparseLhsType& lhs,
const DenseRhsType& rhs, DenseResType& res,
const typename Res::Scalar& alpha)
134 evaluator<Lhs> lhsEval(lhs);
135 for(Index j=0; j<lhs.outerSize(); ++j)
137 typename Rhs::ConstRowXpr rhs_j(rhs.row(j));
138 for(LhsInnerIterator it(lhsEval,j); it ;++it)
139 res.row(it.index()) += (alpha*it.value()) * rhs_j;
// Entry point: accumulates alpha * lhs * rhs into res. It dispatches to the
// sparse_time_dense_product_impl specialization selected by the operand types
// and AlphaType; the remaining template arguments take their defaults.
144 template<
typename SparseLhsType,
typename DenseRhsType,
typename DenseResType,
typename AlphaType>
145 inline void sparse_time_dense_product(
const SparseLhsType& lhs,
const DenseRhsType& rhs, DenseResType& res,
const AlphaType& alpha)
147 sparse_time_dense_product_impl<SparseLhsType,DenseRhsType,DenseResType, AlphaType>::run(lhs, rhs, res, alpha);
// Sparse * Dense product: scaleAndAddTo() performs dst += alpha * lhs * rhs.
// Each operand is first wrapped in its nested_eval type. The second nested_eval
// argument is chosen from the *other* operand's storage order. Presumably this is
// the expected number of reads per coefficient; confirm against nested_eval.
154 template<
typename Lhs,
typename Rhs,
int ProductType>
155 struct generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType>
156 : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,SparseShape,DenseShape,ProductType> >
158 typedef typename Product<Lhs,Rhs>::Scalar Scalar;
160 template<
typename Dest>
161 static void scaleAndAddTo(Dest& dst,
const Lhs& lhs,
const Rhs& rhs,
const Scalar& alpha)
163 typedef typename nested_eval<Lhs,((Rhs::Flags&RowMajorBit)==0) ? 1 : Rhs::ColsAtCompileTime>::type LhsNested;
164 typedef typename nested_eval<Rhs,((Lhs::Flags&RowMajorBit)==0) ? 1 : Dynamic>::type RhsNested;
165 LhsNested lhsNested(lhs);
166 RhsNested rhsNested(rhs);
167 internal::sparse_time_dense_product(lhsNested, rhsNested, dst, alpha);
// A SparseTriangularShape operand on the left is handled exactly like a general
// sparse operand: this specialization simply inherits the SparseShape version.
171 template<
typename Lhs,
typename Rhs,
int ProductType>
172 struct generic_product_impl<Lhs, Rhs, SparseTriangularShape, DenseShape, ProductType>
173 : generic_product_impl<Lhs, Rhs, SparseShape, DenseShape, ProductType>
// Dense * Sparse product, computed through transposition:
//   dst^T += alpha * rhs^T * lhs^T
// This makes the sparse operand the left-hand side that
// sparse_time_dense_product expects. The result is written through a Transpose
// view of dst.
176 template<
typename Lhs,
typename Rhs,
int ProductType>
177 struct generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType>
178 : generic_product_impl_base<Lhs,Rhs,generic_product_impl<Lhs,Rhs,DenseShape,SparseShape,ProductType> >
180 typedef typename Product<Lhs,Rhs>::Scalar Scalar;
182 template<
typename Dst>
183 static void scaleAndAddTo(Dst& dst,
const Lhs& lhs,
const Rhs& rhs,
const Scalar& alpha)
185 typedef typename nested_eval<Lhs,((Rhs::Flags&RowMajorBit)==0) ? Dynamic : 1>::type LhsNested;
186 typedef typename nested_eval<Rhs,((Lhs::Flags&RowMajorBit)==RowMajorBit) ? 1 : Lhs::RowsAtCompileTime>::type RhsNested;
187 LhsNested lhsNested(lhs);
188 RhsNested rhsNested(rhs);
191 Transpose<Dst> dstT(dst);
192 internal::sparse_time_dense_product(rhsNested.transpose(), lhsNested.transpose(), dstT, alpha);
// A SparseTriangularShape operand on the right is handled exactly like a general
// sparse operand: this specialization simply inherits the SparseShape version.
196 template<
typename Lhs,
typename Rhs,
int ProductType>
197 struct generic_product_impl<Lhs, Rhs, DenseShape, SparseTriangularShape, ProductType>
198 : generic_product_impl<Lhs, Rhs, DenseShape, SparseShape, ProductType>
// Evaluator for the outer product of a sparse operand and a dense operand (in
// either order).
// - When NeedToTranspose is true, the roles of LhsT and RhsT are swapped: Lhs1
//   (the operand that is iterated) is RhsT, and ActualRhs is LhsT.
// - A non-sparse Lhs1 is wrapped in SparseView so it can be iterated like a
//   sparse expression.
// NOTE(review): several source lines are missing from this listing. Examples are
// the enum around CoeffReadCost, the tail of the sparse get(), and the member
// declarations of m_lhs, m_outer, m_empty and m_factor. The comments here
// describe only what is visible.
201 template<
typename LhsT,
typename RhsT,
bool NeedToTranspose>
202 struct sparse_dense_outer_product_evaluator
205 typedef typename conditional<NeedToTranspose,RhsT,LhsT>::type Lhs1;
206 typedef typename conditional<NeedToTranspose,LhsT,RhsT>::type ActualRhs;
207 typedef Product<LhsT,RhsT,DefaultProduct> ProdXprType;
211 typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value,
212 Lhs1, SparseView<Lhs1> >::type ActualLhs;
213 typedef typename conditional<is_same<typename internal::traits<Lhs1>::StorageKind,Sparse>::value,
214 Lhs1
const&, SparseView<Lhs1> >::type LhsArg;
216 typedef evaluator<ActualLhs> LhsEval;
217 typedef evaluator<ActualRhs> RhsEval;
218 typedef typename evaluator<ActualLhs>::InnerIterator LhsIterator;
219 typedef typename ProdXprType::Scalar Scalar;
224 CoeffReadCost = HugeCost
// Iterates the non-zeros of the `outer`-th inner vector of the product.
// - It walks the non-zeros of inner vector 0 of ActualLhs (the lhs vector).
// - Each value is scaled by m_factor, which is the rhs coefficient at index
//   `outer` as returned by get().
// - row()/col() substitute m_outer for the fixed coordinate. m_outer is presumably
//   initialised from `outer` in the initializer lines missing here.
227 class InnerIterator :
public LhsIterator
230 InnerIterator(
const sparse_dense_outer_product_evaluator &xprEval, Index outer)
231 : LhsIterator(xprEval.m_lhsXprImpl, 0),
234 m_factor(get(xprEval.m_rhsXprImpl, outer, typename
internal::traits<ActualRhs>::StorageKind() ))
237 EIGEN_STRONG_INLINE Index outer()
const {
return m_outer; }
238 EIGEN_STRONG_INLINE Index row()
const {
return NeedToTranspose ? m_outer : LhsIterator::index(); }
239 EIGEN_STRONG_INLINE Index col()
const {
return NeedToTranspose ? LhsIterator::index() : m_outer; }
241 EIGEN_STRONG_INLINE Scalar value()
const {
return LhsIterator::value() * m_factor; }
// Exhausted when the lhs iterator is done, or when m_empty is set.
242 EIGEN_STRONG_INLINE
operator bool()
const {
return LhsIterator::operator bool() && (!m_empty); }
// Rhs factor lookup for a dense rhs: the rhs coefficient at index `outer`.
245 Scalar
get(
const RhsEval &rhs, Index outer, Dense = Dense())
const 247 return rhs.coeff(outer);
// Rhs factor lookup for a sparse rhs. It inspects the first stored entry of inner
// vector `outer` and accepts it only if it sits at inner index 0 and is non-zero.
// The rest of the function is not visible here. Presumably it returns
// it.value() in that case and otherwise sets m_empty so the iterator yields
// nothing; confirm against the header.
250 Scalar
get(
const RhsEval &rhs, Index outer, Sparse = Sparse())
252 typename RhsEval::InnerIterator it(rhs, outer);
253 if (it && it.index()==0 && it.value()!=Scalar(0))
// Constructor matching product order (LhsT, RhsT) when NeedToTranspose is false,
// because then Lhs1 == LhsT and ActualRhs == RhsT.
264 sparse_dense_outer_product_evaluator(
const Lhs1 &lhs,
const ActualRhs &rhs)
265 : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs)
267 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
// Constructor matching product order (LhsT, RhsT) when NeedToTranspose is true,
// because then ActualRhs == LhsT and Lhs1 == RhsT.
271 sparse_dense_outer_product_evaluator(
const ActualRhs &rhs,
const Lhs1 &lhs)
272 : m_lhs(lhs), m_lhsXprImpl(m_lhs), m_rhsXprImpl(rhs)
274 EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
// Evaluators of the iterated operand (possibly wrapped in SparseView) and of the
// other operand.
279 evaluator<ActualLhs> m_lhsXprImpl;
280 evaluator<ActualRhs> m_rhsXprImpl;
// Outer-product evaluator for Sparse * Dense. It delegates to
// sparse_dense_outer_product_evaluator and swaps the operand roles
// (NeedToTranspose) when the sparse lhs is row-major.
284 template<
typename Lhs,
typename Rhs>
285 struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, SparseShape, DenseShape>
286 : sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor>
288 typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Lhs::IsRowMajor> Base;
290 typedef Product<Lhs, Rhs> XprType;
291 typedef typename XprType::PlainObject PlainObject;
293 explicit product_evaluator(
const XprType& xpr)
294 : Base(xpr.lhs(), xpr.rhs())
// Outer-product evaluator for Dense * Sparse. It delegates to
// sparse_dense_outer_product_evaluator and swaps the operand roles
// (NeedToTranspose) when the sparse rhs is row-major.
299 template<
typename Lhs,
typename Rhs>
300 struct product_evaluator<Product<Lhs, Rhs, DefaultProduct>, OuterProduct, DenseShape, SparseShape>
301 : sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor>
303 typedef sparse_dense_outer_product_evaluator<Lhs,Rhs, Rhs::IsRowMajor> Base;
305 typedef Product<Lhs, Rhs> XprType;
306 typedef typename XprType::PlainObject PlainObject;
308 explicit product_evaluator(
const XprType& xpr)
309 : Base(xpr.lhs(), xpr.rhs())
318 #endif // EIGEN_SPARSEDENSEPRODUCT_H
const unsigned int RowMajorBit
Definition: Constants.h:61
Definition: Constants.h:320
Definition: Constants.h:322
Definition: Eigen_Colamd.h:54