/*
 * ilu.hh
 *
 *  Created on: 2013. 8. 30.
 *      Author: parkmh
 */

#ifndef ILU_HH_
#define ILU_HH_

#include <cstddef>
#include <ostream>
#include <stdexcept>

#include "numvec.hh"
#include "crs_matrix.hh"

/**
 * Incomplete LU factorization of a sparse matrix stored as a CRSMatrix.
 *
 * The factors are kept on the sparsity pattern of the input matrix: factor()
 * copies the matrix into L_ and overwrites the copy in place. solve() then
 * runs a forward and a backward substitution over that storage.
 */
template <class T> class ILU{
public:
	typedef T datatype;
	typedef CRSMatrix<T> matrix;

	/** Builds an object without a factorization; nothing is computed until init() is called. */
	ILU() {}

	/** Builds the object and immediately factors A (forwards to init()). */
	ILU(const CRSMatrix<T>& A);

	/** Computes the factorization of A (forwards to factor()). */
	void init(const CRSMatrix<T>& A);

	/** Copies rhs into x and applies the stored factors to it in place. */
	void solve(const NumVec<T>& rhs, NumVec<T>& x) const;

	/** Writes a fixed descriptive label to os; the ILU argument is not inspected. */
	friend std::ostream&
	operator<<(std::ostream& os, const ILU&){
		return os << "Incomplete LU Decomposition";
	}

	/** Writes a short name for this object to the given stream. */
	void overview(std::ostream& os) const;

private:
	// Copy of the input matrix, overwritten in place with the factors.
	CRSMatrix<T> L_;
	// Result of A.diag(); factor() and solve() use D_[i] as the offset of
	// row i's diagonal entry inside the value array of L_.
	NumVec<int> D_;

	/** Performs the in-place incomplete factorization of A. */
	void factor(const CRSMatrix<T>& A);
};

/**
 * Constructs the object and factors A right away by delegating to init().
 * Any exception raised during the factorization propagates to the caller.
 */
template <class T>
ILU<T>::ILU(const CRSMatrix<T> &A)
{
	this->init(A);
}

/**
 * (Re)builds the factorization from A. This is a public entry point that
 * simply hands the matrix to the private factor() routine.
 */
template <class T>
void ILU<T>::init(const CRSMatrix<T> &A)
{
	this->factor(A);
}

/**
 * Computes an incomplete LU factorization of A on A's own sparsity pattern.
 *
 * A is copied into L_ and the copy is overwritten row by row: entries left of
 * the diagonal become the multipliers of the (unit-diagonal) lower factor,
 * entries on and right of the diagonal become the upper factor. Updates that
 * would fall outside the stored pattern of the current row are dropped.
 *
 * D_ is filled from A.diag(); the code uses D_[k] as the offset of row k's
 * diagonal entry in the value array. The column indices within each row are
 * assumed to be sorted in ascending order (the scan stops at the first column
 * that is not below the diagonal).
 *
 * @param A matrix to factor.
 * @throws std::runtime_error "ILU Fatal error!" if a row has no stored
 *         diagonal entry (including a row with no entries at all).
 * @throws std::runtime_error "ILU Fatal error - zero pivot" if a diagonal
 *         entry is exactly zero when it is about to be inverted.
 */
template <class T>
void ILU<T>::factor(const CRSMatrix<T>& A){
	const size_t n = A.row();

	D_ = A.diag();
	L_ = A;		// work on a copy of A

	// iw maps a column index to the offset of that column's entry in the
	// row currently being eliminated, or -1 when the row has no such entry.
	NumVec<int> iw;
	iw.resize(n);
	int* iwp = iw.begin();
	iw = -1;	// filled once; each row restores the entries it touched

	const size_t* iap = A.iabegin();
	const size_t* jap = A.jabegin();
	datatype* laap = L_.aabegin();
	const int* dp = D_.begin();

	for (size_t i = 0; i < n; ++i){
		const size_t row_begin = iap[i];
		const size_t row_end = iap[i+1];

		// Register where every column of row i lives.
		for (size_t k = row_begin; k < row_end; ++k){
			iwp[jap[k]] = static_cast<int>(k);
		}

		// Eliminate the entries left of the diagonal.
		size_t j = row_begin;
		for (; j < row_end; ++j){
			const size_t col = jap[j];
			if (col >= i){
				break;
			}

			// The diagonal of every already-processed row is stored
			// inverted at this point, so this is a division by the pivot.
			const T tl = laap[j] * laap[dp[col]];
			laap[j] = tl;

			// Subtract tl * (upper part of row col), only where row i
			// already has a stored entry.
			const size_t pivot_row_end = iap[col+1];
			for (size_t jj = static_cast<size_t>(dp[col]) + 1; jj < pivot_row_end; ++jj){
				const int jw = iwp[jap[jj]];
				if (jw != -1){
					laap[jw] -= tl * laap[jj];
				}
			}
		}

		// j must now point at the diagonal entry of row i. An empty row or
		// a row without a diagonal is rejected here instead of reading an
		// unset column index.
		if (j == row_end || jap[j] != i){
			throw std::runtime_error("ILU Fatal error!");
		}
		if (laap[j] == 0.0){
			throw std::runtime_error("ILU Fatal error - zero pivot");
		}
		laap[j] = 1.0 / laap[j];

		// Restore the work vector for the next row.
		for (size_t k = row_begin; k < row_end; ++k){
			iwp[jap[k]] = -1;
		}
	}

	// The diagonals were kept inverted during elimination; invert them back
	// so that solve() can divide by the stored pivots.
	for (size_t i = 0; i < n; ++i){
		laap[dp[i]] = 1.0 / laap[dp[i]];
	}
}

/**
 * Applies the stored factorization to a right-hand side.
 *
 * x is first overwritten with a copy of rhs and then updated in place by
 *   1. a forward substitution with the lower factor, using the entries left
 *      of each row's diagonal and an implicit unit diagonal, and
 *   2. a backward substitution with the upper factor, using the entries
 *      right of each row's diagonal and dividing by the stored diagonal.
 *
 * The function performs no size checks and writes nothing to any stream.
 *
 * @param rhs right-hand side vector.
 * @param x   receives the result; its previous content is discarded.
 */
template<class T>
void ILU<T>::solve(const NumVec<T>& rhs, NumVec<T>& x) const{
	const size_t n = L_.row();
	x = rhs;

	const int *dp = D_.begin();
	const size_t *liap = L_.iabegin();
	const size_t *ljap = L_.jabegin();
	const T* laap = L_.aabegin();
	T* xp = x.begin();

	/*
	 * solve L*x = x, where L is the unit lower triangular factor.
	 * Row 0 has nothing left of its diagonal, so the sweep starts at row 1.
	 */
	for (size_t i = 1; i < n; ++i){
		const size_t diag = static_cast<size_t>(dp[i]);
		for (size_t j = liap[i]; j < diag; ++j){
			xp[i] -= laap[j] * xp[ljap[j]];
		}
	}

	/*
	 * solve U*x = x, walking the rows from last to first.
	 * The "i-- > 0" form keeps the counter unsigned without wrapping.
	 */
	for (size_t i = n; i-- > 0; ){
		const size_t diag = static_cast<size_t>(dp[i]);
		const size_t row_end = liap[i+1];
		for (size_t j = diag + 1; j < row_end; ++j){
			xp[i] -= laap[j] * xp[ljap[j]];
		}
		xp[i] /= laap[diag];
	}
}

/**
 * Writes the short name of this object to os. No trailing newline is
 * emitted and no state of the factorization is reported.
 */
template <class T>
void ILU<T>::overview(std::ostream &os ) const
{
	static const char label[] = "ILU";
	os << label;
}


#endif /* ILU_HH_ */
