OpenMEEG
Loading...
Searching...
No Matches
symmatrix.h
Go to the documentation of this file.
1// Project Name: OpenMEEG (http://openmeeg.github.io)
2// © INRIA and ENPC under the French open source license CeCILL-B.
3// See full copyright notice in the file LICENSE.txt
4// If you make a copy of this file, you must either:
5// - provide also LICENSE.txt and modify this header to refer to it.
6// - replace this header by the LICENSE.txt content.
7
8#pragma once
9
#include <iostream>
#include <cstdlib>
#include <string>
#include <vector>

#include <vector.h>
#include <linop.h>
16
17namespace OpenMEEG {
18
19 class Matrix;
20
21 class OPENMEEGMATHS_EXPORT SymMatrix : public LinOp {
22
23 friend class Vector;
24
25 LinOpValue value;
26
27 public:
28
29 SymMatrix(): LinOp(0,0,SYMMETRIC,2),value() {}
30
31 SymMatrix(const char* fname): LinOp(0,0,SYMMETRIC,2),value() { this->load(fname); }
32 SymMatrix(Dimension N): LinOp(N,N,SYMMETRIC,2),value(size()) { }
33 SymMatrix(Dimension M,Dimension N): LinOp(N,N,SYMMETRIC,2),value(size()) { om_assert(N==M); }
34 SymMatrix(const SymMatrix& S,const DeepCopy): LinOp(S.nlin(),S.nlin(),SYMMETRIC,2),value(S.size(),S.data()) { }
35
36 explicit SymMatrix(const Vector& v);
37 explicit SymMatrix(const Matrix& A);
38
39 size_t size() const { return nlin()*(nlin()+1)/2; };
40 void info() const ;
41
42 Dimension ncol() const { return nlin(); } // SymMatrix only need num_lines
43 Dimension& ncol() { return nlin(); }
44
45 void alloc_data() { value = LinOpValue(size()); }
46 void reference_data(const double* array) { value = LinOpValue(size(),array); }
47
48 bool empty() const { return value.empty(); }
49 void set(double x) ;
50 double* data() const { return value.get(); }
51
52 double operator()(const Index i,const Index j) const {
53 om_assert(i<nlin());
54 om_assert(j<nlin());
55 return data()[(i<=j) ? i+j*(j+1)/2 : j+i*(i+1)/2];
56 }
57
58 double& operator()(const Index i,const Index j) {
59 om_assert(i<nlin());
60 om_assert(j<nlin());
61 return data()[(i<=j) ? i+j*(j+1)/2 : j+i*(i+1)/2];
62 }
63
64 Matrix operator()(const Index i_start,const Index i_end,const Index j_start,const Index j_end) const;
65 Matrix submat(const Index istart,const Index isize,const Index jstart,const Index jsize) const;
66 SymMatrix submat(const Index istart,const Index iend) const;
67 Vector getlin(const Index i) const;
68 void setlin(const Index i,const Vector& v);
69 Vector solveLin(const Vector& B) const;
70 void solveLin(Vector* B,const int nbvect);
72
73 const SymMatrix& operator=(const double d);
74
75 SymMatrix operator+(const SymMatrix& B) const;
76 SymMatrix operator-(const SymMatrix& B) const;
77 Matrix operator*(const SymMatrix& B) const;
78 Matrix operator*(const Matrix& B) const;
79 Vector operator*(const Vector& v) const;
80 SymMatrix operator*(const double x) const;
81 SymMatrix operator/(const double x) const { return (*this)*(1/x); }
82
83 void operator +=(const SymMatrix& B);
84 void operator -=(const SymMatrix& B);
85 void operator *=(const double x);
86 void operator /=(const double x) { (*this)*=(1/x); }
87
88 SymMatrix inverse() const;
89 void invert();
90 SymMatrix posdefinverse() const;
91 double det();
92 // void eigen(Matrix& Z,Vector& D);
93
94 void save(const char* filename) const;
95 void load(const char* filename);
96
97 void save(const std::string& s) const { save(s.c_str()); }
98 void load(const std::string& s) { load(s.c_str()); }
99
100 friend class Matrix;
101 };
102
103 // Returns the solution of (this)*X = B
104
105 inline Vector SymMatrix::solveLin(const Vector& B) const {
106 SymMatrix invA(*this,DEEP_COPY);
107 Vector X(B,DEEP_COPY);
108
109 #ifdef HAVE_LAPACK
110 // Bunch Kaufman factorization
111 BLAS_INT* pivots=new BLAS_INT[nlin()];
112 int Info = 0;
113 DSPTRF('U',sizet_to_int(invA.nlin()),invA.data(),pivots,Info);
114 om_assert(Info==0);
115
116 // Inverse
117 DSPTRS('U',sizet_to_int(invA.nlin()),1,invA.data(),pivots,X.data(),sizet_to_int(invA.nlin()),Info);
118 om_assert(Info==0);
119 delete[] pivots;
120 #else
121 std::cout << "solveLin not defined" << std::endl;
122 #endif
123 return X;
124 }
125
126 // stores in B the solution of (this)*X = B, where B is a set of nbvect vector
127
128 inline void SymMatrix::solveLin(Vector* B,const int nbvect) {
129 SymMatrix invA(*this,DEEP_COPY);
130
131 #ifdef HAVE_LAPACK
132 // Bunch Kaufman Factorization
133 BLAS_INT *pivots=new BLAS_INT[nlin()];
134 int Info = 0;
135 //char *uplo="U";
136 DSPTRF('U',sizet_to_int(invA.nlin()),invA.data(),pivots,Info);
137 om_assert(Info==0);
138 // Inverse
139 for(int i=0; i<nbvect; ++i) {
140 DSPTRS('U',sizet_to_int(invA.nlin()),1,invA.data(),pivots,B[i].data(),sizet_to_int(invA.nlin()),Info);
141 om_assert(Info==0);
142 }
143 delete[] pivots;
144 #else
145 std::cout << "solveLin not defined" << std::endl;
146 #endif
147 }
148
149 inline void SymMatrix::operator+=(const SymMatrix& B) {
150 om_assert(nlin()==B.nlin());
151 #ifdef HAVE_BLAS
152 BLAS(daxpy,DAXPY)(sizet_to_int(nlin()*(nlin()+1)/2), 1.0, B.data(), 1, data() , 1);
153 #else
154 const size_t sz = size();
155 for (size_t i=0; i<sz; ++i)
156 data()[i] += B.data()[i];
157 #endif
158 }
159
160 inline void SymMatrix::operator-=(const SymMatrix& B) {
161 om_assert(nlin()==B.nlin());
162 #ifdef HAVE_BLAS
163 BLAS(daxpy,DAXPY)(sizet_to_int(nlin()*(nlin()+1)/2), -1.0, B.data(), 1, data() , 1);
164 #else
165 const size_t sz = size();
166 for (size_t i=0; i<sz; ++i)
167 data()[i] -= B.data()[i];
168 #endif
169 }
170
172 // supposes (*this) is definite positive
173 SymMatrix invA(*this,DEEP_COPY);
174 #ifdef HAVE_LAPACK
175 // U'U factorization then inverse
176 int Info = 0;
177 DPPTRF('U', sizet_to_int(nlin()),invA.data(),Info);
178 om_assert(Info==0);
179 DPPTRI('U', sizet_to_int(nlin()),invA.data(),Info);
180 om_assert(Info==0);
181 #else
182 std::cerr << "Positive definite inverse not defined" << std::endl;
183 #endif
184 return invA;
185 }
186
187 inline double SymMatrix::det() {
188 SymMatrix invA(*this,DEEP_COPY);
189 double d = 1.0;
190 #ifdef HAVE_LAPACK
191 // Bunch Kaufmqn
192 BLAS_INT *pivots=new BLAS_INT[nlin()];
193 int Info = 0;
194 // TUDUtTt
195 DSPTRF('U', sizet_to_int(invA.nlin()), invA.data(), pivots,Info);
196 if (Info<0) {
197 std::cout << "Big problem in det (DSPTRF)" << std::endl;
198 om_assert(Info==0);
199 }
200 for (size_t i = 0; i< nlin(); i++){
201 if (pivots[i] >= 0) {
202 d *= invA(i,i);
203 } else { // pivots[i] < 0
204 if (i < nlin()-1 && pivots[i] == pivots[i+1]) {
205 d *= (invA(i,i)*invA(i+1,i+1)-invA(i,i+1)*invA(i+1,i));
206 i++;
207 } else {
208 std::cout << "Big problem in det" << std::endl;
209 }
210 }
211 }
212 delete[] pivots;
213 #else
214 throw OpenMEEG::maths::LinearAlgebraError("Determinant not defined without LAPACK");
215 #endif
216 return(d);
217 }
218
219 // inline void SymMatrix::eigen(Matrix& Z,Vector& D ){
220 // // performs the complete eigen-decomposition.
221 // // (*this) = Z.D.Z'
222 // // -> eigenvector are columns of the Matrix Z.
223 // // (*this).Z[:,i] = D[i].Z[:,i]
224 // #ifdef HAVE_LAPACK
225 // SymMatrix symtemp(*this,DEEP_COPY);
226 // D = Vector(nlin());
227 // Z = Matrix(nlin(),nlin());
228 //
229 // int info;
230 // double lworkd;
231 // int lwork;
232 // int liwork;
233 //
234 // DSPEVD('V','U',sizet_to_int(nlin()),symtemp.data(),D.data(),Z.data(),sizet_to_int(nlin()),&lworkd,-1,&liwork,-1,info);
235 // lwork = (int) lworkd;
236 // double * work = new double[lwork];
237 // BLAS_INT *iwork = new BLAS_INT[liwork];
238 // DSPEVD('V','U',sizet_to_int(nlin()),symtemp.data(),D.data(),Z.data(),sizet_to_int(nlin()),work,lwork,iwork,liwork,info);
239 //
240 // delete[] work;
241 // delete[] iwork;
242 // #endif
243 // }
244
245 inline SymMatrix SymMatrix::operator+(const SymMatrix& B) const {
246 om_assert(nlin()==B.nlin());
247 SymMatrix C(*this,DEEP_COPY);
248 C += B;
249 return C;
250 }
251
252 inline SymMatrix SymMatrix::operator-(const SymMatrix& B) const {
253 om_assert(nlin()==B.nlin());
254 SymMatrix C(*this,DEEP_COPY);
255 C -= B;
256 return C;
257 }
258
260 #ifdef HAVE_LAPACK
261 SymMatrix invA(*this, DEEP_COPY);
262 // LU
263 BLAS_INT* pivots = new BLAS_INT[nlin()];
264 int Info = 0;
265 const BLAS_INT M = sizet_to_int(nlin());
266 DSPTRF('U',M,invA.data(),pivots,Info);
267 om_assert(Info==0);
268 // Inverse
269 double* work = new double[nlin()*64];
270 DSPTRI('U',M,invA.data(),pivots,work,Info);
271 om_assert(Info==0);
272 delete[] pivots;
273 delete[] work;
274 return invA;
275 #else
276 throw OpenMEEG::maths::LinearAlgebraError("Inverse not implemented, requires LAPACK");
277 #endif
278 }
279
280 inline void SymMatrix::invert() {
281 #ifdef HAVE_LAPACK
282 // LU
283 BLAS_INT* pivots = new BLAS_INT[nlin()];
284 int Info = 0;
285 const BLAS_INT M = sizet_to_int(nlin());
286 DSPTRF('U',M,data(),pivots,Info);
287 om_assert(Info==0);
288
289 // Inverse
290 double* work = new double[nlin()*64];
291 DSPTRI('U',M,data(),pivots,work,Info);
292 om_assert(Info==0);
293
294 delete[] pivots;
295 delete[] work;
296 return;
297 #else
298 throw OpenMEEG::maths::LinearAlgebraError("Inverse not implemented, requires LAPACK");
299 #endif
300 }
301
302 inline Vector SymMatrix::operator*(const Vector& v) const {
303 om_assert(nlin()==v.size());
304 Vector y(nlin());
305 #ifdef HAVE_BLAS
306 const BLAS_INT M = sizet_to_int(nlin());
307 DSPMV(CblasUpper,M,1.0,data(),v.data(),1,0.0,y.data(),1);
308 #else
309 for (Index i=0; i<nlin(); ++i) {
310 y(i)=0;
311 for (Index j=0; j<nlin(); ++j)
312 y(i)+=(*this)(i,j)*v(j);
313 }
314 #endif
315 return y;
316 }
317
318 inline Vector SymMatrix::getlin(const Index i) const {
319 om_assert(i<nlin());
320 Vector v(ncol());
321 for (Index j=0; j<ncol(); ++j)
322 v(j) = (*this)(i,j);
323 return v;
324 }
325
326 inline void SymMatrix::setlin(const Index i,const Vector& v) {
327 om_assert(v.size()==nlin());
328 om_assert(i<nlin());
329 for (Index j=0; j<ncol(); ++j)
330 (*this)(i,j) = v(j);
331 }
332}
Dimension nlin() const
Definition linop.h:48
Matrix class Matrix class.
Definition matrix.h:28
void save(const std::string &s) const
Definition symmatrix.h:97
Matrix submat(const Index istart, const Index isize, const Index jstart, const Index jsize) const
SymMatrix posdefinverse() const
Definition symmatrix.h:171
Matrix operator*(const Matrix &B) const
SymMatrix operator/(const double x) const
Definition symmatrix.h:81
SymMatrix(const char *fname)
Definition symmatrix.h:31
friend class Matrix
Definition symmatrix.h:100
Vector solveLin(const Vector &B) const
Definition symmatrix.h:105
Matrix solveLin(Matrix &B) const
const SymMatrix & operator=(const double d)
Matrix operator*(const SymMatrix &B) const
SymMatrix operator*(const double x) const
Dimension ncol() const
Definition symmatrix.h:42
void operator+=(const SymMatrix &B)
Definition symmatrix.h:149
size_t size() const
Definition symmatrix.h:39
SymMatrix submat(const Index istart, const Index iend) const
double * data() const
Definition symmatrix.h:50
SymMatrix operator+(const SymMatrix &B) const
Definition symmatrix.h:245
Matrix operator()(const Index i_start, const Index i_end, const Index j_start, const Index j_end) const
void info() const
friend class Vector
Definition symmatrix.h:23
void setlin(const Index i, const Vector &v)
Definition symmatrix.h:326
double & operator()(const Index i, const Index j)
Definition symmatrix.h:58
Dimension & ncol()
Definition symmatrix.h:43
Vector getlin(const Index i) const
Definition symmatrix.h:318
void save(const char *filename) const
SymMatrix inverse() const
Definition symmatrix.h:259
void reference_data(const double *array)
Definition symmatrix.h:46
bool empty() const
Definition symmatrix.h:48
SymMatrix(const SymMatrix &S, const DeepCopy)
Definition symmatrix.h:34
double operator()(const Index i, const Index j) const
Definition symmatrix.h:52
void load(const char *filename)
SymMatrix(const Matrix &A)
SymMatrix(const Vector &v)
SymMatrix(Dimension N)
Definition symmatrix.h:32
void load(const std::string &s)
Definition symmatrix.h:98
void set(double x)
void operator-=(const SymMatrix &B)
Definition symmatrix.h:160
SymMatrix(Dimension M, Dimension N)
Definition symmatrix.h:33
SymMatrix operator-(const SymMatrix &B) const
Definition symmatrix.h:252
size_t size() const
Definition vector.h:40
double * data() const
Definition vector.h:44
DeepCopy
Definition linop.h:84
@ DEEP_COPY
Definition linop.h:84
unsigned Dimension
Definition linop.h:32
double det(const Vect3 &V1, const Vect3 &V2, const Vect3 &V3)
Definition vect3.h:108
unsigned Index
Definition linop.h:33
BLAS_INT sizet_to_int(const unsigned &num)
Definition linop.h:26