/*
 * fftw_wrapper.hh
 *
 *  Created on: 2013. 10. 2.
 *      Author: parkmh
 */

#ifndef FFTW_WRAPPER_HH_
#define FFTW_WRAPPER_HH_

#include "../mesh/point.hh"
#include "../math/numvec.hh"
#include <fftw3.h>
#include <complex>
typedef std::complex<double> Complex;

/*
 * FFT_R2C<dim>: forward real-to-complex DFT of a dim-dimensional array,
 * planned and executed through FFTW.
 *
 * Usage: init() once per input buffer / extent, then execute() as often as
 * needed, then read the result with real(). The spectrum is kept in out_,
 * which holds outSize() complex values.
 *
 * NOTE(review): the implicit copy constructor/assignment copy the raw
 * fftw_plan handle, and the destructor destroys it, so copying an object
 * after init() leads to the same plan being destroyed twice. Avoid copies.
 */
template <int dim>
class FFT_R2C {
	typedef Point<dim,size_t> N;   // per-axis extents
public:
	FFT_R2C();
	~FFT_R2C();

	// Builds the FFTW plan for input `in` with extents `n`.
	void init(NumVec<double>& in, const N& n);
	// Runs the plan created by init().
	void execute();
	// Number of complex coefficients the transform produces.
	int outSize() const;
	// Copies the real part of each coefficient into rOut (not resized).
	void real(NumVec<double>& rOut);

private:
	N n_;                  // extents passed to init()
	fftw_plan plan;        // 0 until init() has been called
	NumVec<Complex> out_;  // output spectrum written by execute()
};

template <int dim>
FFT_R2C<dim>::FFT_R2C()
	: plan(0)   // no plan exists until init() is called
{
}

/*
 * Releases this object's FFTW plan, if init() ever created one.
 *
 * fftw_cleanup() is deliberately NOT called here: FFTW documents that after
 * fftw_cleanup() every existing plan becomes undefined, so calling it from
 * one wrapper's destructor would break plans still owned by other live
 * FFT / FFT_R2C objects.
 */
template <int dim>
FFT_R2C<dim>::~FFT_R2C(){
	if (plan) fftw_destroy_plan(plan);
}


/*
 * Prepares a forward real-to-complex transform of `in`.
 *
 * in : input samples. The plan records in.begin(), so `in` must stay alive
 *      and must not be reallocated while this plan is in use.
 * n  : extent along each of the dim axes; n[0] is the fastest-varying
 *      (contiguous) axis.
 *
 * Any plan from a previous init() is destroyed first, and out_ is resized to
 * outSize() complex values. Works for any dim >= 1 (not only 1..3).
 * NOTE(review): if FFTW fails to build a plan it returns NULL, and execute()
 * does not check for that.
 */
template <int dim>
void FFT_R2C<dim>::init(NumVec<double>& in,const N& n){
	n_ = n;
	out_.resize(static_cast<size_t>(outSize()));

	// Drop any previous plan and clear the handle so it is never destroyed twice.
	if (plan) {
		fftw_destroy_plan(plan);
		plan = 0;
	}

	// FFTW takes extents slowest-varying first, whereas n_[0] is the
	// fastest-varying axis here, so reverse the order. The last FFTW axis
	// (n_[0]) is the one FFTW halves to n/2+1, which matches outSize().
	int dims[dim];
	for (int i = 0; i < dim; i++){
		dims[i] = static_cast<int>(n_[dim - 1 - i]);
	}

	plan = fftw_plan_dft_r2c(
			dim,
			dims,
			in.begin(),
			reinterpret_cast<fftw_complex*>(out_.begin()),
			FFTW_ESTIMATE   // ESTIMATE does not overwrite the arrays while planning
	);
}
/*
 * Runs the plan built by init(): reads the input buffer that was passed to
 * init() and overwrites out_ with the spectrum. init() must have been called.
 */
template <int dim>
void FFT_R2C<dim>::execute(){
	fftw_execute(this->plan);
}

/*
 * Copies the real part of each of the outSize() output coefficients into
 * rOut. rOut is indexed directly and is not resized, so it must already hold
 * at least outSize() elements.
 */
template <int dim>
void FFT_R2C<dim>::real(NumVec<double>& rOut){
	const size_t count = outSize();
	size_t k = 0;
	while (k < count){
		rOut[k] = std::real(out_[k]);
		++k;
	}
}
/*
 * Number of complex values produced by the real-to-complex transform:
 * (n_[0]/2 + 1) * n_[1] * ... * n_[dim-1]. Only the fastest-varying axis
 * (n_[0]) is shortened to n/2+1; the other axes keep their full extent.
 */
template <int dim>
int FFT_R2C<dim>::outSize() const{
	size_t total = n_[0]/2 + 1;
	for (int axis = dim - 1; axis >= 1; --axis){
		total *= n_[axis];
	}
	return static_cast<int>(total);
}
/*
 * FFT<dim>: complex-to-complex DFT of a dim-dimensional array, planned and
 * executed through FFTW. The transform direction is chosen in init().
 * Backward transforms are normalised by the total number of points in
 * execute().
 *
 * NOTE(review): the implicit copy constructor/assignment copy the raw
 * fftw_plan handle, and the destructor destroys it, so copying an object
 * after init() leads to the same plan being destroyed twice. Avoid copies.
 */
template <int dim>
class FFT{
	typedef Point<dim,size_t> N;   // per-axis extents
public:
	FFT();
	~FFT();
	// Builds the plan for input `in`, extents `n`, and FFTW sign `direction`.
	// The default belongs here: a default argument for a member of a class
	// template may only appear on the in-class declaration.
	void init(NumVec<Complex>&, const N&, int direction = FFTW_FORWARD);
	// Runs the plan; normalises the result for FFTW_BACKWARD.
	void execute();
	// Number of complex outputs (product of all extents).
	int outSize() const;
	// Copy the real / imaginary parts of the output into caller storage
	// holding at least outSize() elements (destinations are not resized).
	void real(NumVec<double> &);
	void imag(NumVec<double> &);
	void real(double *);
	void imag(double *);
	// Read-only access to the output spectrum.
	const NumVec<Complex>& result() const;

private:
	N n_;                  // extents passed to init()
	fftw_plan plan;        // 0 until init() has been called
	NumVec<Complex> out_;  // output written by execute()
	int direction_;        // FFTW_FORWARD or FFTW_BACKWARD
};

template <int dim>
FFT<dim>::FFT(){
	plan = 0;
	direction_ = FFTW_FORWARD;
}

/*
 * Releases this object's FFTW plan, if init() ever created one.
 * fftw_cleanup() is deliberately NOT called: it makes every existing FFTW
 * plan undefined, including plans owned by other live wrapper objects.
 */
template <int dim>
FFT<dim>::~FFT(){
	if (plan) fftw_destroy_plan(plan);
}

/*
 * Prepares a complex-to-complex transform of `in`.
 *
 * in        : input samples. The plan records in.begin(), so `in` must stay
 *             alive and must not be reallocated while this plan is in use.
 * n         : extent along each of the dim axes; n[0] is the
 *             fastest-varying (contiguous) axis.
 * direction : FFTW sign, FFTW_FORWARD (default) or FFTW_BACKWARD.
 *
 * Any plan from a previous init() is destroyed first, and out_ is resized to
 * outSize(). Works for any dim >= 1 (not only 1..3).
 * NOTE(review): if FFTW fails to build a plan it returns NULL, and execute()
 * does not check for that.
 */
template <int dim>
void FFT<dim>::init(NumVec<Complex>& in,const N& n, int direction){
	direction_ = direction;
	n_ = n;
	out_.resize(static_cast<size_t>(outSize()));

	// Drop any previous plan and clear the handle so it is never destroyed twice.
	if (plan) {
		fftw_destroy_plan(plan);
		plan = 0;
	}

	// FFTW takes extents slowest-varying first, whereas n_[0] is the
	// fastest-varying axis here, so reverse the order.
	int dims[dim];
	for (int i = 0; i < dim; i++){
		dims[i] = static_cast<int>(n_[dim - 1 - i]);
	}

	plan = fftw_plan_dft(
			dim,
			dims,
			reinterpret_cast<fftw_complex*>(in.begin()),
			reinterpret_cast<fftw_complex*>(out_.begin()),
			direction,
			FFTW_ESTIMATE   // ESTIMATE does not overwrite the arrays while planning
	);
}
/*
 * Runs the plan built by init(), writing the transform into out_.
 *
 * FFTW computes unnormalised transforms: a forward transform followed by a
 * backward one scales the data by the total number of points. For
 * FFTW_BACKWARD, each output is therefore divided by n_.prod() so that the
 * pair round-trips.
 */
template <int dim>
void FFT<dim>::execute(){
	fftw_execute(plan);

	if (direction_ != FFTW_BACKWARD){
		return;
	}

	const size_t count = (size_t)outSize();
	const double scale = n_.prod();
	for (size_t k = 0; k != count; ++k){
		out_[k] /= scale;
	}
}

/*
 * Number of complex outputs of the transform, which is the product of all
 * extents n_[0] * n_[1] * ... * n_[dim-1].
 */
template <int dim>
int FFT<dim>::outSize() const{
	size_t total = 1;
	for (int axis = 0; axis < dim; ++axis){
		total *= n_[axis];
	}
	return static_cast<int>(total);
}

/*
 * Writes the real part of each of the outSize() outputs into rOut.
 * rOut is indexed directly, not resized; it must hold >= outSize() elements.
 */
template <int dim>
void FFT<dim>::real(NumVec<double>& rOut){
	const size_t count = outSize();
	for (size_t k = 0; k != count; ++k){
		rOut[k] = std::real(out_[k]);
	}
}

/*
 * Writes the imaginary part of each of the outSize() outputs into rOut.
 * rOut is indexed directly, not resized; it must hold >= outSize() elements.
 */
template <int dim>
void FFT<dim>::imag(NumVec<double>& rOut){
	const size_t count = outSize();
	size_t k = 0;
	while (k < count){
		rOut[k] = std::imag(out_[k]);
		++k;
	}
}

/*
 * Raw-pointer overload: writes the real part of each of the outSize()
 * outputs to consecutive elements starting at rOut. The caller must provide
 * room for outSize() doubles.
 */
template <int dim>
void FFT<dim>::real(double * rOut){
	const size_t count = outSize();
	double* dst = rOut;
	for (size_t k = 0; k < count; ++k){
		*dst++ = out_[k].real();
	}
}

/*
 * Raw-pointer overload: writes the imaginary part of each of the outSize()
 * outputs to consecutive elements starting at rOut. The caller must provide
 * room for outSize() doubles.
 */
template <int dim>
void FFT<dim>::imag(double *rOut){
	const size_t count = outSize();
	double* dst = rOut;
	for (size_t k = 0; k < count; ++k){
		*dst++ = out_[k].imag();
	}
}

/*
 * Read-only reference to the output buffer that execute() writes into.
 * It holds outSize() complex values once init() has been called.
 */
template <int dim>
const NumVec<Complex>& FFT<dim>::result() const {
	return this->out_;
}

#endif /* FFTW_WRAPPER_HH_ */
