// Copyright (C) 2009 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include "../tester.h"
#include <dlib/matrix.h>
#ifndef DLIB_USE_BLAS
#error "BLAS bindings must be used for this test to make any sense"
#endif
namespace dlib { namespace blas_bindings {

#ifdef DLIB_TEST_BLAS_BINDINGS
    // Test hook: the BLAS bindings call this to count how many times the GER
    // routine gets invoked.  The test code below resets the count by assigning
    // through the returned reference and reads it back after each expression.
    // NOTE(review): presumably the bindings only declare this function when
    // DLIB_TEST_BLAS_BINDINGS is defined, and each test program supplies the
    // definition -- confirm against the bindings header.
    int& counter_ger()
    {
        // A single function-local static shared by every caller.
        static int ger_call_count = 0;
        return ger_call_count;
    }
#endif

}}
namespace
{
    using namespace test;
    using namespace std;

    // Declare the logger we will use in this test.  The name of the logger
    // should start with "test."
    dlib::logger dlog("test.ger");

    // Checks that matrix expressions built from outer products of a column
    // vector and a row vector (optionally added into a matrix) are dispatched
    // to the BLAS GER routine.  Before each expression the GER call counter
    // (dlib::blas_bindings::counter_ger()) is reset to zero; afterwards the
    // test asserts how many GER calls the bindings reported.
    class blas_bindings_ger_tester : public tester
    {
    public:
        blas_bindings_ger_tester (
        ) :
            tester (
                "test_ger",                    // the command line argument name for this test
                "Run tests for GER routines.", // the command line argument description
                0                              // the number of command line arguments for this test
            )
        {}

        // Evaluates several outer-product expressions and checks the number of
        // GER calls each one produced.
        //   m  - square matrix that is updated in place by the += expressions;
        //        its dimension must match the lengths of rv and cv so that
        //        trans(rv)*rv and cv*trans(cv) are conformable with it.
        //   rv - row vector (only read).
        //   cv - column vector (only read).
        template <typename matrix_type, typename cv_type, typename rv_type>
        void test_ger_stuff(
            matrix_type& m,
            const rv_type& rv,
            const cv_type& cv
        ) const
        {
            using namespace dlib;
            using namespace dlib::blas_bindings;

            matrix_type m2;

            // Matrix plus a plain outer product, stored into a new matrix.
            counter_ger() = 0;
            m2 = m + cv*rv;
            DLIB_TEST_MSG(counter_ger() == 1, counter_ger());

            // Outer products spelled with trans() on one or both operands.
            counter_ger() = 0;
            m += trans(rv)*rv;
            DLIB_TEST_MSG(counter_ger() == 1, counter_ger());

            counter_ger() = 0;
            m += trans(rv)*trans(cv);
            DLIB_TEST_MSG(counter_ger() == 1, counter_ger());

            counter_ger() = 0;
            m += cv*trans(cv);
            DLIB_TEST_MSG(counter_ger() == 1, counter_ger());

            // Two outer products in one expression: one GER call for each.
            counter_ger() = 0;
            m += trans(rv)*rv + trans(cv*3*rv);
            DLIB_TEST_MSG(counter_ger() == 2, counter_ger());
        }

        // Same as test_ger_stuff() but with conj() applied to one operand of
        // each outer product.  Only meaningful for complex element types (it is
        // only called for those below).  Each conjugated outer product is still
        // expected to be reported as a single GER call.
        template <typename matrix_type, typename cv_type, typename rv_type>
        void test_ger_stuff_conj(
            matrix_type& m,
            const rv_type& rv,
            const cv_type& cv
        ) const
        {
            using namespace dlib;
            using namespace dlib::blas_bindings;

            counter_ger() = 0;
            m += cv*conj(rv);
            DLIB_TEST_MSG(counter_ger() == 1, counter_ger());

            counter_ger() = 0;
            m += trans(rv)*conj(rv);
            DLIB_TEST_MSG(counter_ger() == 1, counter_ger());

            counter_ger() = 0;
            m += trans(rv)*conj(trans(cv));
            DLIB_TEST_MSG(counter_ger() == 1, counter_ger());

            counter_ger() = 0;
            m += cv*trans(conj(cv));
            DLIB_TEST_MSG(counter_ger() == 1, counter_ger());

            // Two outer products in one expression: one GER call for each.
            counter_ger() = 0;
            m += trans(rv)*rv + trans(cv*3*conj(rv));
            DLIB_TEST_MSG(counter_ger() == 2, counter_ger());
        }

        // Runs the GER checks on random 4x4 matrices and length-4 row/column
        // vectors for each element type (double, float, complex<double>,
        // complex<float>), all using the given memory manager and layout.
        // layout_name is only used to label the log messages so that runs with
        // different layouts can be told apart.
        template <typename mm_type, typename layout_type>
        void test_all_types (
            const char* layout_name
        ) const
        {
            using namespace dlib;

            dlog << dlib::LINFO << "test double (" << layout_name << ")";
            {
                matrix<double,0,0,mm_type,layout_type> m = randm(4,4);
                matrix<double,1,0,mm_type,layout_type> rv = randm(1,4);
                matrix<double,0,1,mm_type,layout_type> cv = randm(4,1);
                test_ger_stuff(m,rv,cv);
            }

            dlog << dlib::LINFO << "test float (" << layout_name << ")";
            {
                matrix<float,0,0,mm_type,layout_type> m = matrix_cast<float>(randm(4,4));
                matrix<float,1,0,mm_type,layout_type> rv = matrix_cast<float>(randm(1,4));
                matrix<float,0,1,mm_type,layout_type> cv = matrix_cast<float>(randm(4,1));
                test_ger_stuff(m,rv,cv);
            }

            dlog << dlib::LINFO << "test complex<double> (" << layout_name << ")";
            {
                matrix<complex<double>,0,0,mm_type,layout_type> m = complex_matrix(randm(4,4), randm(4,4));
                matrix<complex<double>,1,0,mm_type,layout_type> rv = complex_matrix(randm(1,4), randm(1,4));
                matrix<complex<double>,0,1,mm_type,layout_type> cv = complex_matrix(randm(4,1), randm(4,1));
                test_ger_stuff(m,rv,cv);
                test_ger_stuff_conj(m,rv,cv);
            }

            dlog << dlib::LINFO << "test complex<float> (" << layout_name << ")";
            {
                matrix<complex<float>,0,0,mm_type,layout_type> m = matrix_cast<complex<float> >(complex_matrix(randm(4,4), randm(4,4)));
                matrix<complex<float>,1,0,mm_type,layout_type> rv = matrix_cast<complex<float> >(complex_matrix(randm(1,4), randm(1,4)));
                matrix<complex<float>,0,1,mm_type,layout_type> cv = matrix_cast<complex<float> >(complex_matrix(randm(4,1), randm(4,1)));
                test_ger_stuff(m,rv,cv);
                test_ger_stuff_conj(m,rv,cv);
            }
        }

        void perform_test (
        )
        {
            typedef dlib::memory_manager<char>::kernel_1a mm;

            // Matrices with the default template arguments (default memory
            // manager, row major storage) -- i.e. plain matrix<T>, matrix<T,1,0>
            // and matrix<T,0,1>.
            test_all_types<dlib::default_memory_manager, dlib::row_major_layout>("row major");

            // Column major matrices, which exercise the transposed-storage
            // paths of the bindings.
            test_all_types<mm, dlib::column_major_layout>("column major");

            print_spinner();
        }
    };

    // Namespace-scope instance of the tester.  NOTE(review): presumably its
    // construction registers "test_ger" with the test driver -- see tester.h.
    blas_bindings_ger_tester a;
}