Browse Source

FFT-MKL: generalize 2D case to arbitrary multiple dimensions

pull/445/head
Christoph Ruegg 10 years ago
parent
commit
d7bc61240c
  1. 4
      src/Benchmark/Program.cs
  2. 15
      src/NativeProviders/MKL/fft.cpp
  3. 4
      src/Numerics/Providers/Common/Mkl/SafeNativeMethods.cs
  4. 3
      src/Numerics/Providers/FourierTransform/IFourierTransformProvider.cs
  5. 13
      src/Numerics/Providers/FourierTransform/ManagedFourierTransformProvider.cs
  6. 88
      src/Numerics/Providers/FourierTransform/Mkl/MklFourierTransformProvider.cs

4
src/Benchmark/Program.cs

@ -17,8 +17,8 @@ namespace Benchmark
var config = ManualConfig.Create(DefaultConfig.Instance)
.With(new MemoryDiagnoser()); //, new InliningDiagnoser());
//BenchmarkRunner.Run<Transforms.FFT>(config);
BenchmarkRunner.Run<LinearAlgebra.DenseMatrixProduct>(config);
BenchmarkRunner.Run<Transforms.FFT>(config);
//BenchmarkRunner.Run<LinearAlgebra.DenseMatrixProduct>(config);
//Benchmark(new LinearAlgebra.DenseVectorAdd(10000000,1), 10, "Large (10'000'000) - 10x1 iterations");
//Benchmark(new LinearAlgebra.DenseVectorAdd(100,1000), 100, "Small (100) - 100x1000 iterations");

15
src/NativeProviders/MKL/fft.cpp

@ -23,12 +23,9 @@ inline MKL_LONG fft_create_1d(DFTI_DESCRIPTOR_HANDLE* handle, const MKL_LONG n,
}
template<typename Precision>
inline MKL_LONG fft_create_2d(DFTI_DESCRIPTOR_HANDLE* handle, const MKL_LONG m, const MKL_LONG n, const Precision forward_scale, const Precision backward_scale, const DFTI_CONFIG_VALUE precision, const DFTI_CONFIG_VALUE domain)
inline MKL_LONG fft_create_md(DFTI_DESCRIPTOR_HANDLE* handle, MKL_LONG dimensions, MKL_LONG n[], const Precision forward_scale, const Precision backward_scale, const DFTI_CONFIG_VALUE precision, const DFTI_CONFIG_VALUE domain)
{
MKL_LONG sizes[2];
sizes[0] = m;
sizes[1] = n;
MKL_LONG status = DftiCreateDescriptor(handle, precision, domain, 2, sizes);
MKL_LONG status = DftiCreateDescriptor(handle, precision, domain, dimensions, n);
DFTI_DESCRIPTOR_HANDLE descriptor = *handle;
if (0 == status) status = DftiSetValue(descriptor, DFTI_FORWARD_SCALE, forward_scale);
if (0 == status) status = DftiSetValue(descriptor, DFTI_BACKWARD_SCALE, backward_scale);
@ -59,14 +56,14 @@ extern "C" {
return fft_create_1d(handle, n, forward_scale, backward_scale, DFTI_SINGLE, DFTI_COMPLEX);
}
DLLEXPORT MKL_LONG z_fft_create_2d(DFTI_DESCRIPTOR_HANDLE* handle, const MKL_LONG m, const MKL_LONG n, const double forward_scale, const double backward_scale)
DLLEXPORT MKL_LONG z_fft_create_multidim(DFTI_DESCRIPTOR_HANDLE* handle, MKL_LONG dimensions, MKL_LONG n[], const double forward_scale, const double backward_scale)
{
return fft_create_2d(handle, m, n, forward_scale, backward_scale, DFTI_DOUBLE, DFTI_COMPLEX);
return fft_create_md(handle, dimensions, n, forward_scale, backward_scale, DFTI_DOUBLE, DFTI_COMPLEX);
}
DLLEXPORT MKL_LONG c_fft_create_2d(DFTI_DESCRIPTOR_HANDLE* handle, const MKL_LONG m, const MKL_LONG n, const float forward_scale, const float backward_scale)
DLLEXPORT MKL_LONG c_fft_create_multidim(DFTI_DESCRIPTOR_HANDLE* handle, MKL_LONG dimensions, MKL_LONG n[], const float forward_scale, const float backward_scale)
{
return fft_create_2d(handle, m, n, forward_scale, backward_scale, DFTI_SINGLE, DFTI_COMPLEX);
return fft_create_md(handle, dimensions, n, forward_scale, backward_scale, DFTI_SINGLE, DFTI_COMPLEX);
}
DLLEXPORT MKL_LONG z_fft_forward(const DFTI_DESCRIPTOR_HANDLE handle, MKL_Complex16 x[])

4
src/Numerics/Providers/Common/Mkl/SafeNativeMethods.cs

@ -393,10 +393,10 @@ namespace MathNet.Numerics.Providers.Common.Mkl
internal static extern int c_fft_create([Out] out IntPtr handle, int n, float forward_scale, float backward_scale);
[DllImport(_DllName, ExactSpelling = true, SetLastError = false, CallingConvention = CallingConvention.Cdecl)]
internal static extern int z_fft_create_2d([Out] out IntPtr handle, int m, int n, double forward_scale, double backward_scale);
internal static extern int z_fft_create_multidim([Out] out IntPtr handle, int dimensions, [In] int[] n, double forward_scale, double backward_scale);
[DllImport(_DllName, ExactSpelling = true, SetLastError = false, CallingConvention = CallingConvention.Cdecl)]
internal static extern int c_fft_create_2d([Out] out IntPtr handle, int m, int n, float forward_scale, float backward_scale);
internal static extern int c_fft_create_multidim([Out] out IntPtr handle, int dimensions, [In] int[] n, float forward_scale, float backward_scale);
[DllImport(_DllName, ExactSpelling = true, SetLastError = false, CallingConvention = CallingConvention.Cdecl)]
internal static extern int z_fft_forward([In] IntPtr handle, [In, Out] Complex[] x);

3
src/Numerics/Providers/FourierTransform/IFourierTransformProvider.cs

@ -60,5 +60,8 @@ namespace MathNet.Numerics.Providers.FourierTransform
Complex[] Forward(Complex[] complexTimeSpace, FourierTransformScaling scaling);
Complex[] Backward(Complex[] complexFrequenceSpace, FourierTransformScaling scaling);
void ForwardInplaceMultidim(Complex[] complex, int[] dimensions, FourierTransformScaling scaling);
void BackwardInplaceMultidim(Complex[] complex, int[] dimensions, FourierTransformScaling scaling);
}
}

13
src/Numerics/Providers/FourierTransform/ManagedFourierTransformProvider.cs

@ -26,12 +26,11 @@
// OTHER DEALINGS IN THE SOFTWARE.
// </copyright>
using System.Collections;
using System;
using MathNet.Numerics.IntegralTransforms;
namespace MathNet.Numerics.Providers.FourierTransform
{
#if !NOSYSNUMERICS
using Complex = System.Numerics.Complex;
#endif
@ -107,5 +106,15 @@ namespace MathNet.Numerics.Providers.FourierTransform
BackwardInplace(work, scaling);
return work;
}
public void ForwardInplaceMultidim(Complex[] complex, int[] dimensions, FourierTransformScaling scaling)
{
throw new NotImplementedException();
}
public void BackwardInplaceMultidim(Complex[] complex, int[] dimensions, FourierTransformScaling scaling)
{
throw new NotImplementedException();
}
}
}

88
src/Numerics/Providers/FourierTransform/Mkl/MklFourierTransformProvider.cs

@ -40,7 +40,7 @@ namespace MathNet.Numerics.Providers.FourierTransform.Mkl
class Kernel
{
public IntPtr Handle;
public int Length;
public int[] Dimensions;
public FourierTransformScaling Scaling;
}
@ -63,8 +63,8 @@ namespace MathNet.Numerics.Providers.FourierTransform.Mkl
MklProvider.Load(minRevision: 11);
// we only support exactly one major version, since major version changes imply a breaking change.
int fftMajor = SafeNativeMethods.query_capability((int)ProviderCapability.FourierTransformMajor);
int fftMinor = SafeNativeMethods.query_capability((int)ProviderCapability.FourierTransformMinor);
int fftMajor = SafeNativeMethods.query_capability((int) ProviderCapability.FourierTransformMajor);
int fftMinor = SafeNativeMethods.query_capability((int) ProviderCapability.FourierTransformMinor);
if (!(fftMajor == 1 && fftMinor >= 0))
{
throw new NotSupportedException(string.Format("MKL Native Provider not compatible. Expecting fourier transform v1 but provider implements v{0}.", fftMajor));
@ -150,18 +150,78 @@ namespace MathNet.Numerics.Providers.FourierTransform.Mkl
{
kernel = new Kernel
{
Length = length,
Dimensions = new[] {length},
Scaling = scaling
};
SafeNativeMethods.z_fft_create(out kernel.Handle, length, ForwardScaling(scaling, length), BackwardScaling(scaling, length));
return kernel;
}
if (kernel.Length != length || kernel.Scaling != scaling)
if (kernel.Dimensions.Length != 1 || kernel.Dimensions[0] != length || kernel.Scaling != scaling)
{
SafeNativeMethods.x_fft_free(ref kernel.Handle);
SafeNativeMethods.z_fft_create(out kernel.Handle, length, ForwardScaling(scaling, length), BackwardScaling(scaling, length));
kernel.Length = length;
kernel.Dimensions = new[] {length};
kernel.Scaling = scaling;
return kernel;
}
return kernel;
}
Kernel Configure(int[] dimensions, FourierTransformScaling scaling)
{
if (dimensions.Length == 1)
{
return Configure(dimensions[0], scaling);
}
Kernel kernel = Interlocked.Exchange(ref _kernel, null);
if (kernel == null)
{
kernel = new Kernel
{
Dimensions = dimensions,
Scaling = scaling
};
long length = 1;
for (int i = 0; i < dimensions.Length; i++)
{
length *= dimensions[i];
}
SafeNativeMethods.z_fft_create_multidim(out kernel.Handle, dimensions.Length, dimensions, ForwardScaling(scaling, length), BackwardScaling(scaling, length));
return kernel;
}
bool mismatch = kernel.Dimensions.Length != dimensions.Length || kernel.Scaling != scaling;
if (!mismatch)
{
for (int i = 0; i < dimensions.Length; i++)
{
if (dimensions[i] != kernel.Dimensions[i])
{
mismatch = true;
break;
}
}
}
if (mismatch)
{
long length = 1;
for (int i = 0; i < dimensions.Length; i++)
{
length *= dimensions[i];
}
SafeNativeMethods.x_fft_free(ref kernel.Handle);
SafeNativeMethods.z_fft_create_multidim(out kernel.Handle, dimensions.Length, dimensions, ForwardScaling(scaling, length), BackwardScaling(scaling, length));
kernel.Dimensions = dimensions;
kernel.Scaling = scaling;
return kernel;
}
@ -208,7 +268,17 @@ namespace MathNet.Numerics.Providers.FourierTransform.Mkl
return work;
}
static double ForwardScaling(FourierTransformScaling scaling, int length)
public void ForwardInplaceMultidim(Complex[] complex, int[] dimensions, FourierTransformScaling scaling)
{
throw new NotImplementedException();
}
public void BackwardInplaceMultidim(Complex[] complex, int[] dimensions, FourierTransformScaling scaling)
{
throw new NotImplementedException();
}
static double ForwardScaling(FourierTransformScaling scaling, long length)
{
switch (scaling)
{
@ -221,7 +291,7 @@ namespace MathNet.Numerics.Providers.FourierTransform.Mkl
}
}
static double BackwardScaling(FourierTransformScaling scaling, int length)
static double BackwardScaling(FourierTransformScaling scaling, long length)
{
switch (scaling)
{
@ -241,4 +311,4 @@ namespace MathNet.Numerics.Providers.FourierTransform.Mkl
}
}
#endif
#endif
Loading…
Cancel
Save