bug with negative class labels is fixed
This commit is contained in:
parent
aef71a146e
commit
16f50dbe50
@ -215,6 +215,7 @@ CvGBTrees::train( const CvMat* _train_data, int _tflag,
|
||||
cvCopy( _responses, orig_response);
|
||||
orig_response->step = CV_ELEM_SIZE(_responses->type);
|
||||
|
||||
/*
|
||||
if (!is_regression)
|
||||
{
|
||||
int max_label = -1;
|
||||
@ -231,6 +232,38 @@ CvGBTrees::train( const CvMat* _train_data, int _tflag,
|
||||
if (class_labels->data.i[i])
|
||||
class_labels->data.i[i] = ++class_count;
|
||||
}
|
||||
*/
|
||||
if (!is_regression)
|
||||
{
|
||||
class_count = 0;
|
||||
unsigned char * mask = new unsigned char[get_len(orig_response)];
|
||||
for (int i=0; i<get_len(orig_response); ++i)
|
||||
mask[i] = 0;
|
||||
for (int i=0; i<get_len(orig_response); ++i)
|
||||
if (!mask[i])
|
||||
{
|
||||
class_count++;
|
||||
for (int j=i; j<get_len(orig_response); ++j)
|
||||
if (int(orig_response->data.fl[j]) == int(orig_response->data.fl[i]))
|
||||
mask[j] = 1;
|
||||
}
|
||||
delete[] mask;
|
||||
|
||||
class_labels = cvCreateMat(1, class_count, CV_32S);
|
||||
class_labels->data.i[0] = int(orig_response->data.fl[0]);
|
||||
int j = 1;
|
||||
for (int i=1; i<get_len(orig_response); ++i)
|
||||
{
|
||||
int k = 0;
|
||||
while ((int(orig_response->data.fl[i]) - class_labels->data.i[k]) && (k<j))
|
||||
k++;
|
||||
if (k == j)
|
||||
{
|
||||
class_labels->data.i[k] = int(orig_response->data.fl[i]);
|
||||
j++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data->is_classifier = false;
|
||||
|
||||
@ -443,8 +476,16 @@ void CvGBTrees::find_gradient(const int k)
|
||||
exp_sfi += res;
|
||||
}
|
||||
int orig_label = int(resp_data[idx]);
|
||||
/*
|
||||
grad_data[idx] = (float)(!(k-class_labels->data.i[orig_label]+1)) -
|
||||
(float)(exp_fk / exp_sfi);
|
||||
*/
|
||||
int ensemble_label = 0;
|
||||
while (class_labels->data.i[ensemble_label] - orig_label)
|
||||
ensemble_label++;
|
||||
|
||||
grad_data[idx] = (float)(!(k-ensemble_label)) -
|
||||
(float)(exp_fk / exp_sfi);
|
||||
}
|
||||
}; break;
|
||||
|
||||
@ -772,10 +813,13 @@ float CvGBTrees::predict( const CvMat* _sample, const CvMat* _missing,
|
||||
|
||||
delete[] sum;
|
||||
|
||||
/*
|
||||
int orig_class_label = -1;
|
||||
for (int i=0; i<get_len(class_labels); ++i)
|
||||
if (class_labels->data.i[i] == class_label+1)
|
||||
orig_class_label = i;
|
||||
*/
|
||||
int orig_class_label = class_labels->data.i[class_label];
|
||||
|
||||
return float(orig_class_label);
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user