Browse Source

FFT-MKL: generalize 2D case to arbitrary multiple dimensions

pull/445/head
Christoph Ruegg 10 years ago
parent
commit
d7bc61240c
  1. 4
      src/Benchmark/Program.cs
  2. 15
      src/NativeProviders/MKL/fft.cpp
  3. 4
      src/Numerics/Providers/Common/Mkl/SafeNativeMethods.cs
  4. 3
      src/Numerics/Providers/FourierTransform/IFourierTransformProvider.cs
  5. 13
      src/Numerics/Providers/FourierTransform/ManagedFourierTransformProvider.cs
  6. 88
      src/Numerics/Providers/FourierTransform/Mkl/MklFourierTransformProvider.cs

4
src/Benchmark/Program.cs

@ -17,8 +17,8 @@ namespace Benchmark
var config = ManualConfig.Create(DefaultConfig.Instance) var config = ManualConfig.Create(DefaultConfig.Instance)
.With(new MemoryDiagnoser()); //, new InliningDiagnoser()); .With(new MemoryDiagnoser()); //, new InliningDiagnoser());
//BenchmarkRunner.Run<Transforms.FFT>(config); BenchmarkRunner.Run<Transforms.FFT>(config);
BenchmarkRunner.Run<LinearAlgebra.DenseMatrixProduct>(config); //BenchmarkRunner.Run<LinearAlgebra.DenseMatrixProduct>(config);
//Benchmark(new LinearAlgebra.DenseVectorAdd(10000000,1), 10, "Large (10'000'000) - 10x1 iterations"); //Benchmark(new LinearAlgebra.DenseVectorAdd(10000000,1), 10, "Large (10'000'000) - 10x1 iterations");
//Benchmark(new LinearAlgebra.DenseVectorAdd(100,1000), 100, "Small (100) - 100x1000 iterations"); //Benchmark(new LinearAlgebra.DenseVectorAdd(100,1000), 100, "Small (100) - 100x1000 iterations");

15
src/NativeProviders/MKL/fft.cpp

@ -23,12 +23,9 @@ inline MKL_LONG fft_create_1d(DFTI_DESCRIPTOR_HANDLE* handle, const MKL_LONG n,
} }
template<typename Precision> template<typename Precision>
inline MKL_LONG fft_create_2d(DFTI_DESCRIPTOR_HANDLE* handle, const MKL_LONG m, const MKL_LONG n, const Precision forward_scale, const Precision backward_scale, const DFTI_CONFIG_VALUE precision, const DFTI_CONFIG_VALUE domain) inline MKL_LONG fft_create_md(DFTI_DESCRIPTOR_HANDLE* handle, MKL_LONG dimensions, MKL_LONG n[], const Precision forward_scale, const Precision backward_scale, const DFTI_CONFIG_VALUE precision, const DFTI_CONFIG_VALUE domain)
{ {
MKL_LONG sizes[2]; MKL_LONG status = DftiCreateDescriptor(handle, precision, domain, dimensions, n);
sizes[0] = m;
sizes[1] = n;
MKL_LONG status = DftiCreateDescriptor(handle, precision, domain, 2, sizes);
DFTI_DESCRIPTOR_HANDLE descriptor = *handle; DFTI_DESCRIPTOR_HANDLE descriptor = *handle;
if (0 == status) status = DftiSetValue(descriptor, DFTI_FORWARD_SCALE, forward_scale); if (0 == status) status = DftiSetValue(descriptor, DFTI_FORWARD_SCALE, forward_scale);
if (0 == status) status = DftiSetValue(descriptor, DFTI_BACKWARD_SCALE, backward_scale); if (0 == status) status = DftiSetValue(descriptor, DFTI_BACKWARD_SCALE, backward_scale);
@ -59,14 +56,14 @@ extern "C" {
return fft_create_1d(handle, n, forward_scale, backward_scale, DFTI_SINGLE, DFTI_COMPLEX); return fft_create_1d(handle, n, forward_scale, backward_scale, DFTI_SINGLE, DFTI_COMPLEX);
} }
DLLEXPORT MKL_LONG z_fft_create_2d(DFTI_DESCRIPTOR_HANDLE* handle, const MKL_LONG m, const MKL_LONG n, const double forward_scale, const double backward_scale) DLLEXPORT MKL_LONG z_fft_create_multidim(DFTI_DESCRIPTOR_HANDLE* handle, MKL_LONG dimensions, MKL_LONG n[], const double forward_scale, const double backward_scale)
{ {
return fft_create_2d(handle, m, n, forward_scale, backward_scale, DFTI_DOUBLE, DFTI_COMPLEX); return fft_create_md(handle, dimensions, n, forward_scale, backward_scale, DFTI_DOUBLE, DFTI_COMPLEX);
} }
DLLEXPORT MKL_LONG c_fft_create_2d(DFTI_DESCRIPTOR_HANDLE* handle, const MKL_LONG m, const MKL_LONG n, const float forward_scale, const float backward_scale) DLLEXPORT MKL_LONG c_fft_create_multidim(DFTI_DESCRIPTOR_HANDLE* handle, MKL_LONG dimensions, MKL_LONG n[], const float forward_scale, const float backward_scale)
{ {
return fft_create_2d(handle, m, n, forward_scale, backward_scale, DFTI_SINGLE, DFTI_COMPLEX); return fft_create_md(handle, dimensions, n, forward_scale, backward_scale, DFTI_SINGLE, DFTI_COMPLEX);
} }
DLLEXPORT MKL_LONG z_fft_forward(const DFTI_DESCRIPTOR_HANDLE handle, MKL_Complex16 x[]) DLLEXPORT MKL_LONG z_fft_forward(const DFTI_DESCRIPTOR_HANDLE handle, MKL_Complex16 x[])

4
src/Numerics/Providers/Common/Mkl/SafeNativeMethods.cs

@ -393,10 +393,10 @@ namespace MathNet.Numerics.Providers.Common.Mkl
internal static extern int c_fft_create([Out] out IntPtr handle, int n, float forward_scale, float backward_scale); internal static extern int c_fft_create([Out] out IntPtr handle, int n, float forward_scale, float backward_scale);
[DllImport(_DllName, ExactSpelling = true, SetLastError = false, CallingConvention = CallingConvention.Cdecl)] [DllImport(_DllName, ExactSpelling = true, SetLastError = false, CallingConvention = CallingConvention.Cdecl)]
internal static extern int z_fft_create_2d([Out] out IntPtr handle, int m, int n, double forward_scale, double backward_scale); internal static extern int z_fft_create_multidim([Out] out IntPtr handle, int dimensions, [In] int[] n, double forward_scale, double backward_scale);
[DllImport(_DllName, ExactSpelling = true, SetLastError = false, CallingConvention = CallingConvention.Cdecl)] [DllImport(_DllName, ExactSpelling = true, SetLastError = false, CallingConvention = CallingConvention.Cdecl)]
internal static extern int c_fft_create_2d([Out] out IntPtr handle, int m, int n, float forward_scale, float backward_scale); internal static extern int c_fft_create_multidim([Out] out IntPtr handle, int dimensions, [In] int[] n, float forward_scale, float backward_scale);
[DllImport(_DllName, ExactSpelling = true, SetLastError = false, CallingConvention = CallingConvention.Cdecl)] [DllImport(_DllName, ExactSpelling = true, SetLastError = false, CallingConvention = CallingConvention.Cdecl)]
internal static extern int z_fft_forward([In] IntPtr handle, [In, Out] Complex[] x); internal static extern int z_fft_forward([In] IntPtr handle, [In, Out] Complex[] x);

3
src/Numerics/Providers/FourierTransform/IFourierTransformProvider.cs

@ -60,5 +60,8 @@ namespace MathNet.Numerics.Providers.FourierTransform
Complex[] Forward(Complex[] complexTimeSpace, FourierTransformScaling scaling); Complex[] Forward(Complex[] complexTimeSpace, FourierTransformScaling scaling);
Complex[] Backward(Complex[] complexFrequenceSpace, FourierTransformScaling scaling); Complex[] Backward(Complex[] complexFrequenceSpace, FourierTransformScaling scaling);
void ForwardInplaceMultidim(Complex[] complex, int[] dimensions, FourierTransformScaling scaling);
void BackwardInplaceMultidim(Complex[] complex, int[] dimensions, FourierTransformScaling scaling);
} }
} }

13
src/Numerics/Providers/FourierTransform/ManagedFourierTransformProvider.cs

@ -26,12 +26,11 @@
// OTHER DEALINGS IN THE SOFTWARE. // OTHER DEALINGS IN THE SOFTWARE.
// </copyright> // </copyright>
using System.Collections; using System;
using MathNet.Numerics.IntegralTransforms; using MathNet.Numerics.IntegralTransforms;
namespace MathNet.Numerics.Providers.FourierTransform namespace MathNet.Numerics.Providers.FourierTransform
{ {
#if !NOSYSNUMERICS #if !NOSYSNUMERICS
using Complex = System.Numerics.Complex; using Complex = System.Numerics.Complex;
#endif #endif
@ -107,5 +106,15 @@ namespace MathNet.Numerics.Providers.FourierTransform
BackwardInplace(work, scaling); BackwardInplace(work, scaling);
return work; return work;
} }
public void ForwardInplaceMultidim(Complex[] complex, int[] dimensions, FourierTransformScaling scaling)
{
throw new NotImplementedException();
}
public void BackwardInplaceMultidim(Complex[] complex, int[] dimensions, FourierTransformScaling scaling)
{
throw new NotImplementedException();
}
} }
} }

88
src/Numerics/Providers/FourierTransform/Mkl/MklFourierTransformProvider.cs

@ -40,7 +40,7 @@ namespace MathNet.Numerics.Providers.FourierTransform.Mkl
class Kernel class Kernel
{ {
public IntPtr Handle; public IntPtr Handle;
public int Length; public int[] Dimensions;
public FourierTransformScaling Scaling; public FourierTransformScaling Scaling;
} }
@ -63,8 +63,8 @@ namespace MathNet.Numerics.Providers.FourierTransform.Mkl
MklProvider.Load(minRevision: 11); MklProvider.Load(minRevision: 11);
// we only support exactly one major version, since major version changes imply a breaking change. // we only support exactly one major version, since major version changes imply a breaking change.
int fftMajor = SafeNativeMethods.query_capability((int)ProviderCapability.FourierTransformMajor); int fftMajor = SafeNativeMethods.query_capability((int) ProviderCapability.FourierTransformMajor);
int fftMinor = SafeNativeMethods.query_capability((int)ProviderCapability.FourierTransformMinor); int fftMinor = SafeNativeMethods.query_capability((int) ProviderCapability.FourierTransformMinor);
if (!(fftMajor == 1 && fftMinor >= 0)) if (!(fftMajor == 1 && fftMinor >= 0))
{ {
throw new NotSupportedException(string.Format("MKL Native Provider not compatible. Expecting fourier transform v1 but provider implements v{0}.", fftMajor)); throw new NotSupportedException(string.Format("MKL Native Provider not compatible. Expecting fourier transform v1 but provider implements v{0}.", fftMajor));
@ -150,18 +150,78 @@ namespace MathNet.Numerics.Providers.FourierTransform.Mkl
{ {
kernel = new Kernel kernel = new Kernel
{ {
Length = length, Dimensions = new[] {length},
Scaling = scaling Scaling = scaling
}; };
SafeNativeMethods.z_fft_create(out kernel.Handle, length, ForwardScaling(scaling, length), BackwardScaling(scaling, length)); SafeNativeMethods.z_fft_create(out kernel.Handle, length, ForwardScaling(scaling, length), BackwardScaling(scaling, length));
return kernel; return kernel;
} }
if (kernel.Length != length || kernel.Scaling != scaling) if (kernel.Dimensions.Length != 1 || kernel.Dimensions[0] != length || kernel.Scaling != scaling)
{ {
SafeNativeMethods.x_fft_free(ref kernel.Handle); SafeNativeMethods.x_fft_free(ref kernel.Handle);
SafeNativeMethods.z_fft_create(out kernel.Handle, length, ForwardScaling(scaling, length), BackwardScaling(scaling, length)); SafeNativeMethods.z_fft_create(out kernel.Handle, length, ForwardScaling(scaling, length), BackwardScaling(scaling, length));
kernel.Length = length; kernel.Dimensions = new[] {length};
kernel.Scaling = scaling;
return kernel;
}
return kernel;
}
Kernel Configure(int[] dimensions, FourierTransformScaling scaling)
{
if (dimensions.Length == 1)
{
return Configure(dimensions[0], scaling);
}
Kernel kernel = Interlocked.Exchange(ref _kernel, null);
if (kernel == null)
{
kernel = new Kernel
{
Dimensions = dimensions,
Scaling = scaling
};
long length = 1;
for (int i = 0; i < dimensions.Length; i++)
{
length *= dimensions[i];
}
SafeNativeMethods.z_fft_create_multidim(out kernel.Handle, dimensions.Length, dimensions, ForwardScaling(scaling, length), BackwardScaling(scaling, length));
return kernel;
}
bool mismatch = kernel.Dimensions.Length != dimensions.Length || kernel.Scaling != scaling;
if (!mismatch)
{
for (int i = 0; i < dimensions.Length; i++)
{
if (dimensions[i] != kernel.Dimensions[i])
{
mismatch = true;
break;
}
}
}
if (mismatch)
{
long length = 1;
for (int i = 0; i < dimensions.Length; i++)
{
length *= dimensions[i];
}
SafeNativeMethods.x_fft_free(ref kernel.Handle);
SafeNativeMethods.z_fft_create_multidim(out kernel.Handle, dimensions.Length, dimensions, ForwardScaling(scaling, length), BackwardScaling(scaling, length));
kernel.Dimensions = dimensions;
kernel.Scaling = scaling; kernel.Scaling = scaling;
return kernel; return kernel;
} }
@ -208,7 +268,17 @@ namespace MathNet.Numerics.Providers.FourierTransform.Mkl
return work; return work;
} }
static double ForwardScaling(FourierTransformScaling scaling, int length) public void ForwardInplaceMultidim(Complex[] complex, int[] dimensions, FourierTransformScaling scaling)
{
throw new NotImplementedException();
}
public void BackwardInplaceMultidim(Complex[] complex, int[] dimensions, FourierTransformScaling scaling)
{
throw new NotImplementedException();
}
static double ForwardScaling(FourierTransformScaling scaling, long length)
{ {
switch (scaling) switch (scaling)
{ {
@ -221,7 +291,7 @@ namespace MathNet.Numerics.Providers.FourierTransform.Mkl
} }
} }
static double BackwardScaling(FourierTransformScaling scaling, int length) static double BackwardScaling(FourierTransformScaling scaling, long length)
{ {
switch (scaling) switch (scaling)
{ {
@ -241,4 +311,4 @@ namespace MathNet.Numerics.Providers.FourierTransform.Mkl
} }
} }
#endif #endif
Loading…
Cancel
Save