added smoke test on EM, fixed EM reading #1570 (thanks to mr.pppoe),
This commit is contained in:
parent
ec793df30f
commit
8d9d964550
@ -141,8 +141,6 @@ void CvEM::read( CvFileStorage* fs, CvFileNode* node )
|
|||||||
CvFileNode* em_node = 0;
|
CvFileNode* em_node = 0;
|
||||||
CvFileNode* tmp_node = 0;
|
CvFileNode* tmp_node = 0;
|
||||||
CvSeq* seq = 0;
|
CvSeq* seq = 0;
|
||||||
CvMat **tmp_covs = 0;
|
|
||||||
CvMat **tmp_cov_rotate_mats = 0;
|
|
||||||
|
|
||||||
read_params( fs, node );
|
read_params( fs, node );
|
||||||
|
|
||||||
@ -156,13 +154,10 @@ void CvEM::read( CvFileStorage* fs, CvFileNode* node )
|
|||||||
CV_CALL( inv_eigen_values = (CvMat*)cvReadByName( fs, em_node, "inv_eigen_values" ));
|
CV_CALL( inv_eigen_values = (CvMat*)cvReadByName( fs, em_node, "inv_eigen_values" ));
|
||||||
|
|
||||||
// Size of all the following data
|
// Size of all the following data
|
||||||
data_size = params.nclusters*2*sizeof(CvMat*);
|
data_size = params.nclusters*sizeof(CvMat*);
|
||||||
|
|
||||||
CV_CALL( tmp_covs = (CvMat**)cvAlloc( data_size ));
|
|
||||||
memset( tmp_covs, 0, data_size );
|
|
||||||
|
|
||||||
tmp_cov_rotate_mats = tmp_covs + params.nclusters;
|
|
||||||
|
|
||||||
|
CV_CALL( covs = (CvMat**)cvAlloc( data_size ));
|
||||||
|
memset( covs, 0, data_size );
|
||||||
CV_CALL( tmp_node = cvGetFileNodeByName( fs, em_node, "covs" ));
|
CV_CALL( tmp_node = cvGetFileNodeByName( fs, em_node, "covs" ));
|
||||||
seq = tmp_node->data.seq;
|
seq = tmp_node->data.seq;
|
||||||
if( !CV_NODE_IS_SEQ(tmp_node->tag) || seq->total != params.nclusters)
|
if( !CV_NODE_IS_SEQ(tmp_node->tag) || seq->total != params.nclusters)
|
||||||
@ -170,24 +165,23 @@ void CvEM::read( CvFileStorage* fs, CvFileNode* node )
|
|||||||
CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
|
CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
|
||||||
for( int i = 0; i < params.nclusters; i++ )
|
for( int i = 0; i < params.nclusters; i++ )
|
||||||
{
|
{
|
||||||
CV_CALL( tmp_covs[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
|
CV_CALL( covs[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
|
||||||
CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
|
CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
CV_CALL( cov_rotate_mats = (CvMat**)cvAlloc( data_size ));
|
||||||
|
memset( cov_rotate_mats, 0, data_size );
|
||||||
CV_CALL( tmp_node = cvGetFileNodeByName( fs, em_node, "cov_rotate_mats" ));
|
CV_CALL( tmp_node = cvGetFileNodeByName( fs, em_node, "cov_rotate_mats" ));
|
||||||
seq = tmp_node->data.seq;
|
seq = tmp_node->data.seq;
|
||||||
if( !CV_NODE_IS_SEQ(tmp_node->tag) || seq->total != params.nclusters)
|
if( !CV_NODE_IS_SEQ(tmp_node->tag) || seq->total != params.nclusters)
|
||||||
CV_ERROR( CV_StsParseError, "Missing or invalid sequence of rotated cov. matrices" );
|
CV_ERROR( CV_StsParseError, "Missing or invalid sequence of covariance matrices" );
|
||||||
CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
|
CV_CALL( cvStartReadSeq( seq, &reader, 0 ));
|
||||||
for( int i = 0; i < params.nclusters; i++ )
|
for( int i = 0; i < params.nclusters; i++ )
|
||||||
{
|
{
|
||||||
CV_CALL( tmp_cov_rotate_mats[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
|
CV_CALL( cov_rotate_mats[i] = (CvMat*)cvRead( fs, (CvFileNode*)reader.ptr ));
|
||||||
CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
|
CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
|
||||||
}
|
}
|
||||||
|
|
||||||
covs = tmp_covs;
|
|
||||||
cov_rotate_mats = tmp_cov_rotate_mats;
|
|
||||||
|
|
||||||
ok = true;
|
ok = true;
|
||||||
__END__;
|
__END__;
|
||||||
|
|
||||||
@ -862,10 +856,10 @@ void CvEM::kmeans( const CvVectors& train_data, int nclusters, CvMat* labels,
|
|||||||
{
|
{
|
||||||
int i, nsamples = train_data.count, dims = train_data.dims;
|
int i, nsamples = train_data.count, dims = train_data.dims;
|
||||||
cv::Ptr<CvMat> temp_mat = cvCreateMat(nsamples, dims, CV_32F);
|
cv::Ptr<CvMat> temp_mat = cvCreateMat(nsamples, dims, CV_32F);
|
||||||
|
|
||||||
for( i = 0; i < nsamples; i++ )
|
for( i = 0; i < nsamples; i++ )
|
||||||
memcpy( temp_mat->data.ptr + temp_mat->step*i, train_data.data.fl[i], dims*sizeof(float));
|
memcpy( temp_mat->data.ptr + temp_mat->step*i, train_data.data.fl[i], dims*sizeof(float));
|
||||||
|
|
||||||
cvKMeans2(temp_mat, nclusters, labels, termcrit, 10);
|
cvKMeans2(temp_mat, nclusters, labels, termcrit, 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1240,20 +1234,20 @@ CvEM::CvEM( const Mat& samples, const Mat& sample_idx, CvEMParams params )
|
|||||||
{
|
{
|
||||||
means = weights = probs = inv_eigen_values = log_weight_div_det = 0;
|
means = weights = probs = inv_eigen_values = log_weight_div_det = 0;
|
||||||
covs = cov_rotate_mats = 0;
|
covs = cov_rotate_mats = 0;
|
||||||
|
|
||||||
// just invoke the train() method
|
// just invoke the train() method
|
||||||
train(samples, sample_idx, params);
|
train(samples, sample_idx, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool CvEM::train( const Mat& _samples, const Mat& _sample_idx,
|
bool CvEM::train( const Mat& _samples, const Mat& _sample_idx,
|
||||||
CvEMParams _params, Mat* _labels )
|
CvEMParams _params, Mat* _labels )
|
||||||
{
|
{
|
||||||
CvMat samples = _samples, sidx = _sample_idx, labels, *plabels = 0;
|
CvMat samples = _samples, sidx = _sample_idx, labels, *plabels = 0;
|
||||||
|
|
||||||
if( _labels )
|
if( _labels )
|
||||||
{
|
{
|
||||||
int nsamples = sidx.data.ptr ? sidx.rows : samples.rows;
|
int nsamples = sidx.data.ptr ? sidx.rows : samples.rows;
|
||||||
|
|
||||||
if( !(_labels->data && _labels->type() == CV_32SC1 &&
|
if( !(_labels->data && _labels->type() == CV_32SC1 &&
|
||||||
(_labels->cols == 1 || _labels->rows == 1) &&
|
(_labels->cols == 1 || _labels->rows == 1) &&
|
||||||
_labels->cols + _labels->rows - 1 == nsamples) )
|
_labels->cols + _labels->rows - 1 == nsamples) )
|
||||||
@ -1267,7 +1261,7 @@ float
|
|||||||
CvEM::predict( const Mat& _sample, Mat* _probs ) const
|
CvEM::predict( const Mat& _sample, Mat* _probs ) const
|
||||||
{
|
{
|
||||||
CvMat sample = _sample, probs, *pprobs = 0;
|
CvMat sample = _sample, probs, *pprobs = 0;
|
||||||
|
|
||||||
if( _probs )
|
if( _probs )
|
||||||
{
|
{
|
||||||
int nclusters = params.nclusters;
|
int nclusters = params.nclusters;
|
||||||
|
@ -332,6 +332,82 @@ void CV_EMTest::run( int /*start_from*/ )
|
|||||||
ts->set_failed_test_info( code );
|
ts->set_failed_test_info( code );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class CV_EMTest_Smoke : public cvtest::BaseTest {
|
||||||
|
public:
|
||||||
|
CV_EMTest_Smoke() {}
|
||||||
|
protected:
|
||||||
|
virtual void run( int /*start_from*/ )
|
||||||
|
{
|
||||||
|
int code = cvtest::TS::OK;
|
||||||
|
CvEM em;
|
||||||
|
|
||||||
|
Mat samples = Mat(3,2,CV_32F);
|
||||||
|
samples.at<float>(0,0) = 1;
|
||||||
|
samples.at<float>(1,0) = 2;
|
||||||
|
samples.at<float>(2,0) = 3;
|
||||||
|
|
||||||
|
CvEMParams params;
|
||||||
|
params.nclusters = 2;
|
||||||
|
|
||||||
|
Mat labels;
|
||||||
|
|
||||||
|
em.train(samples, Mat(), params, &labels);
|
||||||
|
|
||||||
|
Mat firstResult(samples.rows, 1, CV_32FC1);
|
||||||
|
for( int i = 0; i < samples.rows; i++)
|
||||||
|
firstResult.at<float>(i) = em.predict( samples.row(i) );
|
||||||
|
|
||||||
|
// Write out
|
||||||
|
string filename = tempfile() + ".xml";
|
||||||
|
{
|
||||||
|
FileStorage fs = FileStorage(filename, FileStorage::WRITE);
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
em.write(fs.fs, "EM");
|
||||||
|
}
|
||||||
|
catch(...)
|
||||||
|
{
|
||||||
|
ts->printf( cvtest::TS::LOG, "Crash in write method.\n" );
|
||||||
|
ts->set_failed_test_info( cvtest::TS::FAIL_EXCEPTION );
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
em.clear();
|
||||||
|
|
||||||
|
// Read in
|
||||||
|
{
|
||||||
|
FileStorage fs = FileStorage(filename, FileStorage::READ);
|
||||||
|
FileNode fileNode = fs["EM"];
|
||||||
|
|
||||||
|
try
|
||||||
|
{
|
||||||
|
em.read(const_cast<CvFileStorage*>(fileNode.fs), const_cast<CvFileNode*>(fileNode.node));
|
||||||
|
}
|
||||||
|
catch(...)
|
||||||
|
{
|
||||||
|
ts->printf( cvtest::TS::LOG, "Crash in read method.\n" );
|
||||||
|
ts->set_failed_test_info( cvtest::TS::FAIL_EXCEPTION );
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
remove( filename.c_str() );
|
||||||
|
|
||||||
|
int errCaseCount = 0;
|
||||||
|
for( int i = 0; i < samples.rows; i++)
|
||||||
|
errCaseCount = std::abs(em.predict(samples.row(i)) - firstResult.at<float>(i)) < FLT_EPSILON ? 0 : 1;
|
||||||
|
|
||||||
|
if( errCaseCount > 0 )
|
||||||
|
{
|
||||||
|
ts->printf( cvtest::TS::LOG, "Different prediction results before writeing and after reading (errCaseCount=%d).\n", errCaseCount );
|
||||||
|
code = cvtest::TS::FAIL_BAD_ACCURACY;
|
||||||
|
}
|
||||||
|
|
||||||
|
ts->set_failed_test_info( code );
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
TEST(ML_KMeans, accuracy) { CV_KMeansTest test; test.safe_run(); }
|
TEST(ML_KMeans, accuracy) { CV_KMeansTest test; test.safe_run(); }
|
||||||
TEST(ML_KNearest, accuracy) { CV_KNearestTest test; test.safe_run(); }
|
TEST(ML_KNearest, accuracy) { CV_KNearestTest test; test.safe_run(); }
|
||||||
TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); }
|
TEST(ML_EM, accuracy) { CV_EMTest test; test.safe_run(); }
|
||||||
|
TEST(ML_EM, smoke) { CV_EMTest_Smoke test; test.safe_run(); }
|
||||||
|
Loading…
x
Reference in New Issue
Block a user