1
// This file is part of Eigen, a lightweight C++ template library
4
// Copyright (C) 2008-2010 Gael Guennebaud <gael.guennebaud@inria.fr>
6
// Eigen is free software; you can redistribute it and/or
7
// modify it under the terms of the GNU Lesser General Public
8
// License as published by the Free Software Foundation; either
9
// version 3 of the License, or (at your option) any later version.
11
// Alternatively, you can redistribute it and/or
12
// modify it under the terms of the GNU General Public License as
13
// published by the Free Software Foundation; either version 2 of
14
// the License, or (at your option) any later version.
16
// Eigen is distributed in the hope that it will be useful, but WITHOUT ANY
17
// WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
18
// FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License or the
19
// GNU General Public License for more details.
21
// You should have received a copy of the GNU Lesser General Public
22
// License and a copy of the GNU General Public License along with
23
// Eigen. If not, see <http://www.gnu.org/licenses/>.
25
#ifndef EIGEN_SPARSESPARSEPRODUCT_H
26
#define EIGEN_SPARSESPARSEPRODUCT_H
30
// Experimental sparse * sparse product kernel (called by the
// sparse_product_selector2 specializations).
//
// \param lhs  left operand, traversed with Lhs::InnerIterator
// \param rhs  right operand, traversed with Rhs::InnerIterator
// \param res  output matrix; resized to (lhs.innerSize() x rhs.outerSize())
//             and given a reservation based on an estimated fill ratio
//
// Sizes are taken through innerSize()/outerSize() rather than rows()/cols()
// so that callers can pass operands whose storage order is "faked".
//
// NOTE(review): in this listing the statements that combine x and y and that
// write into `res` are not visible (the insertion code below is commented
// out) -- confirm against the upstream file before relying on this function.
template<typename Lhs, typename Rhs, typename ResultType>
31
static void sparse_product_impl2(const Lhs& lhs, const Rhs& rhs, ResultType& res)
33
typedef typename remove_all<Lhs>::type::Scalar Scalar;
34
typedef typename remove_all<Lhs>::type::Index Index;
36
// make sure to call innerSize/outerSize since we fake the storage order.
37
Index rows = lhs.innerSize();
38
Index cols = rhs.outerSize();
39
// the shared dimension of the product must agree
eigen_assert(lhs.outerSize() == rhs.innerSize());
41
// per-result-column work buffers, each of length `rows`
std::vector<bool> mask(rows,false);
42
Matrix<Scalar,Dynamic,1> values(rows);
43
Matrix<Index,Dynamic,1> indices(rows);
45
// estimate the number of non zero entries:
// (fill ratio of lhs) * (average nnz per outer vector of rhs), clamped to 1
46
float ratioLhs = float(lhs.nonZeros())/(float(lhs.rows())*float(lhs.cols()));
47
float avgNnzPerRhsColumn = float(rhs.nonZeros())/float(cols);
48
float ratioRes = (std::min)(ratioLhs * avgNnzPerRhsColumn, 1.f);
50
// int t200 = rows/(log2(200)*1.39);
51
// int t = (rows*100)/139;
53
res.resize(rows, cols);
54
res.reserve(Index(ratioRes*rows*cols));
55
// we compute each column of the result, one after the other
56
for (Index j=0; j<cols; ++j)
61
// walk the non zeros of the j-th outer vector of rhs
for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
63
Scalar y = rhsIt.value();
64
Index k = rhsIt.index();
65
// walk the non zeros of the k-th outer vector of lhs
for (typename Lhs::InnerIterator lhsIt(lhs, k); lhsIt; ++lhsIt)
67
Index i = lhsIt.index();
68
Scalar x = lhsIt.value();
80
// FIXME reserve nnz non zeros
81
// FIXME implement fast sort algorithms for very small nnz
82
// if the result is sparse enough => use a quick sort
83
// otherwise => loop through the entire vector
84
// In order to avoid to perform an expensive log2 when the
85
// result is clearly very sparse we use a linear bound up to 200.
86
// if((nnz<200 && nnz<t200) || nnz * log2(nnz) < t)
88
// if(nnz>1) std::sort(indices.data(),indices.data()+nnz);
89
// for(int k=0; k<nnz; ++k)
91
// int i = indices[k];
92
// res.insertBackNoCheck(j,i) = values[i];
99
// for(int i=0; i<rows; ++i)
104
// res.insertBackNoCheck(j,i) = values[i];
113
// perform a pseudo in-place sparse * sparse product assuming all matrices are col major
114
// Sparse * sparse product kernel used by sparse_product_selector.
//
// \param lhs  left operand, traversed by outer vector with Lhs::InnerIterator
// \param rhs  right operand, traversed by outer vector with Rhs::InnerIterator
// \param res  output matrix; resized, given a reservation from an estimated
//             fill ratio, then filled one outer vector at a time through
//             insertBackByOuterInner()
//
// For each outer index j of rhs, the scaled lhs vectors selected by the non
// zeros of rhs's j-th vector are accumulated into an AmbiVector, whose
// entries are then appended to outer vector j of res.
//
// NOTE(review): brace-only lines and the `else` between the two resize()
// calls are not visible in this listing -- confirm against the upstream file.
template<typename Lhs, typename Rhs, typename ResultType>
115
static void sparse_product_impl(const Lhs& lhs, const Rhs& rhs, ResultType& res)
117
// return sparse_product_impl2(lhs,rhs,res);
119
typedef typename remove_all<Lhs>::type::Scalar Scalar;
120
typedef typename remove_all<Lhs>::type::Index Index;
122
// make sure to call innerSize/outerSize since we fake the storage order.
123
Index rows = lhs.innerSize();
124
Index cols = rhs.outerSize();
125
//int size = lhs.outerSize();
126
// the shared dimension of the product must agree
eigen_assert(lhs.outerSize() == rhs.innerSize());
128
// allocate a temporary buffer (accumulator for one result vector, length `rows`)
129
AmbiVector<Scalar,Index> tempVector(rows);
131
// estimate the number of non zero entries:
// (fill ratio of lhs) * (average nnz per outer vector of rhs), clamped to 1
132
float ratioLhs = float(lhs.nonZeros())/(float(lhs.rows())*float(lhs.cols()));
133
float avgNnzPerRhsColumn = float(rhs.nonZeros())/float(cols);
134
float ratioRes = (std::min)(ratioLhs * avgNnzPerRhsColumn, 1.f);
136
// mimics a resizeByInnerOuter: `rows`/`cols` here are inner/outer sizes, so
// the arguments are swapped when the result type is row-major.
137
if(ResultType::IsRowMajor)
138
res.resize(cols, rows);
140
res.resize(rows, cols);
142
res.reserve(Index(ratioRes*rows*cols));
143
for (Index j=0; j<cols; ++j)
145
// let's do a more accurate determination of the nnz ratio for the current column j of res
146
//float ratioColRes = (std::min)(ratioLhs * rhs.innerNonZeros(j), 1.f);
147
// FIXME find a nice way to get the number of nonzeros of a sub matrix (here an inner vector)
148
// until then, the global estimate is used for every result vector
float ratioColRes = ratioRes;
149
tempVector.init(ratioColRes);
150
tempVector.setZero();
151
for (typename Rhs::InnerIterator rhsIt(rhs, j); rhsIt; ++rhsIt)
153
// FIXME should be written like this: tmp += rhsIt.value() * lhs.col(rhsIt.index())
154
tempVector.restart();
155
Scalar x = rhsIt.value();
156
for (typename Lhs::InnerIterator lhsIt(lhs, rhsIt.index()); lhsIt; ++lhsIt)
158
// accumulate x * lhs(i, k) into entry i of the temporary vector
tempVector.coeffRef(lhsIt.index()) += lhsIt.value() * x;
162
// append the accumulated entries to outer vector j of the result
for (typename AmbiVector<Scalar,Index>::Iterator it(tempVector); it; ++it)
163
res.insertBackByOuterInner(j,it.index()) = it.value();
168
// Dispatcher for sparse = sparse * sparse.
// The three trailing int parameters default to the RowMajorBit of the flags
// of Lhs, Rhs and ResultType respectively; the specializations below are
// selected on that (lhs order, rhs order, result order) triple.
template<typename Lhs, typename Rhs, typename ResultType,
169
int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
170
int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
171
int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
172
struct sparse_product_selector;
174
// Case: col-major lhs * col-major rhs -> col-major result.
// run() evaluates the product with sparse_product_impl into a temporary _res
// of the result type, sized like res.
// NOTE(review): the statement that transfers _res into res is not visible in
// this listing -- confirm against the upstream file.
template<typename Lhs, typename Rhs, typename ResultType>
175
struct sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
177
typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
179
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
181
// std::cerr << __LINE__ << "\n";
182
typename remove_all<ResultType>::type _res(res.rows(), res.cols());
183
sparse_product_impl<Lhs,Rhs,ResultType>(lhs, rhs, _res);
188
// Case: col-major lhs * col-major rhs -> row-major result.
// run() evaluates the product with sparse_product_impl into a temporary
// SparseMatrix with the default storage order.
// NOTE(review): the statement that assigns _res to res is not visible in
// this listing -- confirm against the upstream file.
template<typename Lhs, typename Rhs, typename ResultType>
189
struct sparse_product_selector<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
191
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
193
// std::cerr << __LINE__ << "\n";
194
// we need a col-major matrix to hold the result
195
typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
196
SparseTemporaryType _res(res.rows(), res.cols());
197
sparse_product_impl<Lhs,Rhs,SparseTemporaryType>(lhs, rhs, _res);
202
// Case: row-major lhs * row-major rhs -> row-major result.
// run() calls sparse_product_impl with the operands swapped (rhs first, then
// lhs), writing into a temporary _res of the result type.
// NOTE(review): the statement that transfers _res into res is not visible in
// this listing -- confirm against the upstream file.
template<typename Lhs, typename Rhs, typename ResultType>
203
struct sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
205
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
207
// std::cerr << __LINE__ << "\n";
208
// let's transpose the product to get a column x column product
209
typename remove_all<ResultType>::type _res(res.rows(), res.cols());
210
sparse_product_impl<Rhs,Lhs,ResultType>(rhs, lhs, _res);
215
// Case: row-major lhs * row-major rhs -> col-major result.
// run() first copies both operands into col-major SparseMatrix temporaries,
// then runs sparse_product_impl on those, writing directly into res.
// The commented-out code at the end is an alternative based on transposing
// the product.
template<typename Lhs, typename Rhs, typename ResultType>
216
struct sparse_product_selector<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
218
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
220
// std::cerr << "here...\n";
221
typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
222
// explicit storage-order conversion of both operands
ColMajorMatrix colLhs(lhs);
223
ColMajorMatrix colRhs(rhs);
224
// std::cerr << "more...\n";
225
sparse_product_impl<ColMajorMatrix,ColMajorMatrix,ResultType>(colLhs, colRhs, res);
226
// std::cerr << "OK.\n";
228
// let's transpose the product to get a column x column product
230
// typedef SparseMatrix<typename ResultType::Scalar> SparseTemporaryType;
231
// SparseTemporaryType _res(res.cols(), res.rows());
232
// sparse_product_impl<Rhs,Lhs,SparseTemporaryType>(rhs, lhs, _res);
233
// res = _res.transpose();
237
// NOTE the 2 other cases (col row *) must never occur since they are caught
238
// by ProductReturnType which transforms them to (col col *) by evaluating rhs.
240
} // end namespace internal
242
// sparse = sparse * sparse
243
// Assignment from a sparse * sparse product expression.
// Forwards product.lhs(), product.rhs() and derived() to
// internal::sparse_product_selector, instantiated on the operand types with
// remove_all applied and on Derived as the result type.
// NOTE(review): the declared return type is Derived&, but the return
// statement is not visible in this listing -- confirm against upstream.
template<typename Derived>
244
template<typename Lhs, typename Rhs>
245
inline Derived& SparseMatrixBase<Derived>::operator=(const SparseSparseProduct<Lhs,Rhs>& product)
247
// std::cerr << "there..." << typeid(Lhs).name() << " " << typeid(Lhs).name() << " " << (Derived::Flags&&RowMajorBit) << "\n";
248
internal::sparse_product_selector<
249
typename internal::remove_all<Lhs>::type,
250
typename internal::remove_all<Rhs>::type,
251
Derived>::run(product.lhs(),product.rhs(),derived());
257
// Dispatcher for the experimental product built on sparse_product_impl2.
// The three trailing int parameters default to the RowMajorBit of the flags
// of Lhs, Rhs and ResultType respectively; the specializations below are
// selected on that (lhs order, rhs order, result order) triple.
template<typename Lhs, typename Rhs, typename ResultType,
258
int LhsStorageOrder = traits<Lhs>::Flags&RowMajorBit,
259
int RhsStorageOrder = traits<Rhs>::Flags&RowMajorBit,
260
int ResStorageOrder = traits<ResultType>::Flags&RowMajorBit>
261
struct sparse_product_selector2;
263
// Case: col-major lhs * col-major rhs -> col-major result.
// run() calls sparse_product_impl2 on the operands as given, writing
// directly into res.
template<typename Lhs, typename Rhs, typename ResultType>
264
struct sparse_product_selector2<Lhs,Rhs,ResultType,ColMajor,ColMajor,ColMajor>
266
typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
268
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
270
sparse_product_impl2<Lhs,Rhs,ResultType>(lhs, rhs, res);
274
// Case: row-major lhs * col-major rhs -> col-major result.
// As shown here, run() is a stub: it only marks its arguments as unused, and
// the intended implementation (converting rhs to row-major and calling
// sparse_product_impl2 with swapped operands) is commented out.
template<typename Lhs, typename Rhs, typename ResultType>
275
struct sparse_product_selector2<Lhs,Rhs,ResultType,RowMajor,ColMajor,ColMajor>
277
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
279
// prevent warnings until the code is fixed
280
EIGEN_UNUSED_VARIABLE(lhs);
281
EIGEN_UNUSED_VARIABLE(rhs);
282
EIGEN_UNUSED_VARIABLE(res);
284
// typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix;
285
// RowMajorMatrix rhsRow = rhs;
286
// RowMajorMatrix resRow(res.rows(), res.cols());
287
// sparse_product_impl2<RowMajorMatrix,Lhs,RowMajorMatrix>(rhsRow, lhs, resRow);
292
// Case: col-major lhs * row-major rhs -> col-major result.
// run() copies lhs into a row-major temporary and calls sparse_product_impl2
// with the operands swapped (rhs first), writing into a row-major temporary.
// NOTE(review): the statement that transfers resRow into res is not visible
// in this listing -- confirm against the upstream file.
template<typename Lhs, typename Rhs, typename ResultType>
293
struct sparse_product_selector2<Lhs,Rhs,ResultType,ColMajor,RowMajor,ColMajor>
295
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
297
typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix;
298
// storage-order conversion of lhs
RowMajorMatrix lhsRow = lhs;
299
RowMajorMatrix resRow(res.rows(), res.cols());
300
sparse_product_impl2<Rhs,RowMajorMatrix,RowMajorMatrix>(rhs, lhsRow, resRow);
305
// Case: row-major lhs * row-major rhs -> col-major result.
// run() calls sparse_product_impl2 with the operands swapped (rhs first),
// writing into a row-major temporary sized like res.
// NOTE(review): the statement that transfers resRow into res is not visible
// in this listing -- confirm against the upstream file.
template<typename Lhs, typename Rhs, typename ResultType>
306
struct sparse_product_selector2<Lhs,Rhs,ResultType,RowMajor,RowMajor,ColMajor>
308
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
310
typedef SparseMatrix<typename ResultType::Scalar,RowMajor> RowMajorMatrix;
311
RowMajorMatrix resRow(res.rows(), res.cols());
312
sparse_product_impl2<Rhs,Lhs,RowMajorMatrix>(rhs, lhs, resRow);
318
// Case: col-major lhs * col-major rhs -> row-major result.
// run() calls sparse_product_impl2 on the operands as given, writing into a
// col-major temporary sized like res.
// NOTE(review): the statement that transfers resCol into res is not visible
// in this listing -- confirm against the upstream file.
template<typename Lhs, typename Rhs, typename ResultType>
319
struct sparse_product_selector2<Lhs,Rhs,ResultType,ColMajor,ColMajor,RowMajor>
321
typedef typename traits<typename remove_all<Lhs>::type>::Scalar Scalar;
323
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
325
typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
326
ColMajorMatrix resCol(res.rows(), res.cols());
327
sparse_product_impl2<Lhs,Rhs,ColMajorMatrix>(lhs, rhs, resCol);
332
// Case: row-major lhs * col-major rhs -> row-major result.
// run() copies lhs into a col-major temporary and calls sparse_product_impl2
// on (lhsCol, rhs), writing into a col-major temporary sized like res.
// NOTE(review): the statement that transfers resCol into res is not visible
// in this listing -- confirm against the upstream file.
template<typename Lhs, typename Rhs, typename ResultType>
333
struct sparse_product_selector2<Lhs,Rhs,ResultType,RowMajor,ColMajor,RowMajor>
335
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
337
typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
338
// storage-order conversion of lhs
ColMajorMatrix lhsCol = lhs;
339
ColMajorMatrix resCol(res.rows(), res.cols());
340
sparse_product_impl2<ColMajorMatrix,Rhs,ColMajorMatrix>(lhsCol, rhs, resCol);
345
// Case: col-major lhs * row-major rhs -> row-major result.
// run() copies rhs into a col-major temporary and calls sparse_product_impl2
// on (lhs, rhsCol), writing into a col-major temporary sized like res.
// NOTE(review): the statement that transfers resCol into res is not visible
// in this listing -- confirm against the upstream file.
template<typename Lhs, typename Rhs, typename ResultType>
346
struct sparse_product_selector2<Lhs,Rhs,ResultType,ColMajor,RowMajor,RowMajor>
348
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
350
typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
351
// storage-order conversion of rhs
ColMajorMatrix rhsCol = rhs;
352
ColMajorMatrix resCol(res.rows(), res.cols());
353
sparse_product_impl2<Lhs,ColMajorMatrix,ColMajorMatrix>(lhs, rhsCol, resCol);
358
// Case: row-major lhs * row-major rhs -> row-major result.
// run() copies both operands into col-major temporaries and calls
// sparse_product_impl2 on them, writing into a col-major temporary sized
// like res. The commented-out lines are an earlier variant that swapped the
// operands instead of converting them.
// Note that the ColMajorMatrix typedef appears twice with the same definition.
// NOTE(review): the statement that transfers resCol into res is not visible
// in this listing -- confirm against the upstream file.
template<typename Lhs, typename Rhs, typename ResultType>
359
struct sparse_product_selector2<Lhs,Rhs,ResultType,RowMajor,RowMajor,RowMajor>
361
static void run(const Lhs& lhs, const Rhs& rhs, ResultType& res)
363
typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
364
// ColMajorMatrix lhsTr(lhs);
365
// ColMajorMatrix rhsTr(rhs);
366
// ColMajorMatrix aux(res.rows(), res.cols());
367
// sparse_product_impl2<Rhs,Lhs,ColMajorMatrix>(rhs, lhs, aux);
368
// // ColMajorMatrix aux2 = aux.transpose();
370
typedef SparseMatrix<typename ResultType::Scalar,ColMajor> ColMajorMatrix;
371
// storage-order conversion of both operands
ColMajorMatrix lhsCol(lhs);
372
ColMajorMatrix rhsCol(rhs);
373
ColMajorMatrix resCol(res.rows(), res.cols());
374
sparse_product_impl2<ColMajorMatrix,ColMajorMatrix,ColMajorMatrix>(lhsCol, rhsCol, resCol);
379
} // end namespace internal
381
// Entry point of the experimental product: evaluates lhs * rhs into *this.
// Forwards lhs, rhs and derived() to internal::sparse_product_selector2,
// instantiated on the operand types with remove_all applied and on Derived
// as the result type. The explicit resize of derived() is commented out here.
template<typename Derived>
382
template<typename Lhs, typename Rhs>
383
inline void SparseMatrixBase<Derived>::_experimentalNewProduct(const Lhs& lhs, const Rhs& rhs)
385
//derived().resize(lhs.rows(), rhs.cols());
386
internal::sparse_product_selector2<
387
typename internal::remove_all<Lhs>::type,
388
typename internal::remove_all<Rhs>::type,
389
Derived>::run(lhs,rhs,derived());
393
// sparse * sparse
// Returns an object of type SparseSparseProductReturnType<Derived,OtherDerived>::Type
// constructed from derived() and other.derived().
template<typename Derived>
394
template<typename OtherDerived>
395
inline const typename SparseSparseProductReturnType<Derived,OtherDerived>::Type
396
SparseMatrixBase<Derived>::operator*(const SparseMatrixBase<OtherDerived> &other) const
398
return typename SparseSparseProductReturnType<Derived,OtherDerived>::Type(derived(), other.derived());
401
#endif // EIGEN_SPARSESPARSEPRODUCT_H