//------------------------------------------------------------------------------
// (c) 02-2002 Gottfried Chen
//------------------------------------------------------------------------------

#ifndef FFT_H
#define FFT_H

//#pragma once


#include "FftRoutines.h"


// USE_FFTW selects the FFT backend. When it is defined, the FFTW library is
// used; when it is undefined, my own FFT routines (FftRoutines) are used
// instead (they are slower but are not bound by the FFTW open source
// license). Undefine it if you run into trouble using FFTW.
// It is currently UNDEFINED (the #define below is commented out).
#undef USE_FFTW
//#define USE_FFTW
//#define FFTW_ENABLE_FLOAT

#ifdef USE_FFTW
#   include <fftw.h>
#   include <rfftw.h>
#   ifndef FFTW_ENABLE_FLOAT
#       error You have add "#define FFTW_ENABLE_FLOAT" to fftw.h
#   endif
#endif



// Usage: Don't use any of the classes below directly!!!
//
// Use the typedefs Fft and Fft2D instead (the double-precision variants
// DFft and DFft2D are currently commented out).
//
// The typedefs are defined further down in this file, after the FFTW
// wrapper classes. This ensures that the correct FFT implementation is used
// depending on the USE_FFTW #define.


// Built-in 1D complex FFT that forwards to FftRoutines::fft.
// The array handed to calculate() must contain 2^<sizePower> elements.
// Prefer the Fft typedef over naming this class directly.
class MyFft
{
public:
    // sizePower: log2 of the number of elements in the array.
    // sign:      forwarded unchanged to FftRoutines::fft.
    MyFft(unsigned int sizePower, int sign = 1);

    // Runs the transform on 'data'. No separate output buffer is taken, so
    // the result presumably overwrites 'data' (depends on FftRoutines::fft).
    void calculate(nComplex* data) const;

private:
    unsigned int mSizePower; // log2 of the transform length
    int mSign;               // transform sign passed to FftRoutines::fft
};

// Built-in 2D complex FFT that forwards to FftRoutines::fft2D.
// Works on an x*y array, where x is the number of elements in a row and y
// the number of elements in a column. The index within a row varies
// fastest, i.e. the array is stored as y consecutive rows of x elements.
// Both dimensions are given as powers of two (2^xSizePower, 2^ySizePower).
class MyFft2D
{
public:
    // xSizePower: log2 of the row length.
    // ySizePower: log2 of the number of rows.
    // sign:       forwarded unchanged to FftRoutines::fft2D.
    MyFft2D(unsigned int xSizePower, unsigned int ySizePower, int sign = 1);

    // Runs the 2D transform on 'data' (no separate output buffer is taken).
    void calculate(nComplex* data) const;

private:
    unsigned int mXSizePower; // log2 of the row length
    unsigned int mYSizePower; // log2 of the row count
    int mSign;                // transform sign passed to FftRoutines::fft2D
};

#ifdef USE_FFTW

    // FFTW-backed, in-place 1D complex FFT. Like MyFft, the array passed to
    // calculate() must contain 2^<sizePower> elements, so the Fft typedef can
    // switch between the two backends without changing callers.
    class Fftw
    {
    public:
        // sizePower: log2 of the transform length.
        // sign:      < 0 selects FFTW_FORWARD (negative exponent), otherwise
        //            FFTW_BACKWARD. Defaults to 1, as in MyFft.
        Fftw(unsigned int sizePower, int sign = 1)
        {
            // GE_ASSERT(sign == 1 || sign == -1);
            // fftw_create_plan expects the transform length, not its log2.
            const int size = static_cast<int>(1u << sizePower);
            mPlan = fftw_create_plan(size,
                                     sign < 0 ? FFTW_FORWARD : FFTW_BACKWARD,
                                     FFTW_ESTIMATE | FFTW_IN_PLACE);
        }

        ~Fftw()
        {
            fftw_destroy_plan(mPlan);
        }

        // Transforms 'data' in place (the plan was created with
        // FFTW_IN_PLACE, so the output argument is unused and passed as 0).
        void calculate(std::complex<float>* data) const
        {
            fftw_one(mPlan, reinterpret_cast<fftw_complex*>(data), 0);
        }

    private:
        // The destructor destroys mPlan, so a copy would destroy the same
        // plan twice. Copying is therefore disabled (declared, never defined).
        Fftw(const Fftw&);
        Fftw& operator=(const Fftw&);

        fftw_plan mPlan;
    };

    // FFTW-backed, in-place 2D complex FFT. Uses the same layout as MyFft2D:
    // 2^ySizePower rows, each holding 2^xSizePower elements, with the index
    // within a row varying fastest.
    class Fftw2D
    {
    public:
        // xSizePower: log2 of the row length.
        // ySizePower: log2 of the number of rows.
        // sign:       < 0 selects FFTW_FORWARD, otherwise FFTW_BACKWARD.
        Fftw2D(unsigned int xSizePower, unsigned int ySizePower, int sign = 1)
        {
            // GE_ASSERT(sign == 1 || sign == -1);
            // FFTW takes sizes (not log2 values) in row-major order: the
            // slowest-varying dimension (row count, y) first, then the row
            // length (x). This matches the argument order MyFft2D passes to
            // FftRoutines::fft2D.
            const int rows = static_cast<int>(1u << ySizePower);
            const int cols = static_cast<int>(1u << xSizePower);
            mPlan = fftw2d_create_plan(rows, cols,
                                       sign < 0 ? FFTW_FORWARD : FFTW_BACKWARD,
                                       FFTW_ESTIMATE | FFTW_IN_PLACE);
        }

        ~Fftw2D()
        {
            fftwnd_destroy_plan(mPlan);
        }

        // Transforms 'data' in place (the output argument is unused, 0).
        void calculate(std::complex<float>* data) const
        {
            fftwnd_one(mPlan, reinterpret_cast<fftw_complex*>(data), 0);
        }

    private:
        // The destructor destroys mPlan, so a copy would destroy the same
        // plan twice. Copying is therefore disabled (declared, never defined).
        Fftw2D(const Fftw2D&);
        Fftw2D& operator=(const Fftw2D&);

        fftwnd_plan mPlan;
    };

    // Public names when FFTW is enabled: the FFTW wrappers.
    typedef Fftw2D Fft2D;
    typedef Fftw   Fft;

#else // USE_FFTW

    // Public names when FFTW is disabled: the built-in FftRoutines wrappers.
    typedef MyFft2D Fft2D;
    typedef MyFft   Fft;

#endif // USE_FFTW


// Double-precision versions of the above. Disabled: MyFft/MyFft2D are no longer templates, and FFTW is not used for them.
//typedef MyFft<double> DFft;
//typedef MyFft2D<double> DFft2D;


//#include "Fft.inl"
//------------------------------------------------------------------------------
// Records the size exponent and the sign. No work happens until calculate().
inline MyFft::MyFft(unsigned int sizePower, int sign)
    : mSizePower(sizePower)
    , mSign(sign)
{
}

//------------------------------------------------------------------------------
// Forwards to the 1D routine with the stored size exponent and sign.
inline void MyFft::calculate(nComplex* data) const
{
    FftRoutines::fft(data, this->mSizePower, this->mSign);
}

//------------------------------------------------------------------------------
// Records both size exponents and the sign. No work happens until calculate().
inline MyFft2D::MyFft2D(unsigned int xSizePower, unsigned int ySizePower,
                        int sign)
    : mXSizePower(xSizePower)
    , mYSizePower(ySizePower)
    , mSign(sign)
{
}

//------------------------------------------------------------------------------
// Forwards to the 2D routine. Note the argument order: the row-count exponent
// (y) is passed before the row-length exponent (x).
inline void MyFft2D::calculate(nComplex* data) const
{
    FftRoutines::fft2D(data, this->mYSizePower, this->mXSizePower, this->mSign);
}

#endif // FFT_H