[*] Fixed #974 ("GPU CascadeClassifier fails with some training files"): Moved IsNodeLeaf bit from NodeDescriptor to FeatureDescriptor for both left and right nodes, therefore from now on max number of rects in a feature is 31

This commit is contained in:
Anton Obukhov
2011-04-04 11:47:21 +00:00
parent 2388fa223e
commit 58476b64a6
3 changed files with 97 additions and 60 deletions

View File

@@ -444,10 +444,22 @@ __global__ void applyHaarClassifierAnchorParallel(Ncv32u *d_IImg, Ncv32u IImgStr
HaarClassifierNodeDescriptor32 nodeLeft = curNode.getLeftNodeDesc();
HaarClassifierNodeDescriptor32 nodeRight = curNode.getRightNodeDesc();
Ncv32f nodeThreshold = curNode.getThreshold();
HaarClassifierNodeDescriptor32 nextNodeDescriptor;
nextNodeDescriptor = (curNodeVal < scaleArea * pixelStdDev * nodeThreshold) ? nodeLeft : nodeRight;
if (nextNodeDescriptor.isLeaf())
HaarClassifierNodeDescriptor32 nextNodeDescriptor;
NcvBool nextNodeIsLeaf;
if (curNodeVal < scaleArea * pixelStdDev * nodeThreshold)
{
nextNodeDescriptor = nodeLeft;
nextNodeIsLeaf = featuresDesc.isLeftNodeLeaf();
}
else
{
nextNodeDescriptor = nodeRight;
nextNodeIsLeaf = featuresDesc.isRightNodeLeaf();
}
if (nextNodeIsLeaf)
{
Ncv32f tmpLeafValue = nextNodeDescriptor.getLeafValue();
curStageSum += tmpLeafValue;
@@ -572,10 +584,22 @@ __global__ void applyHaarClassifierClassifierParallel(Ncv32u *d_IImg, Ncv32u IIm
HaarClassifierNodeDescriptor32 nodeLeft = curNode.getLeftNodeDesc();
HaarClassifierNodeDescriptor32 nodeRight = curNode.getRightNodeDesc();
Ncv32f nodeThreshold = curNode.getThreshold();
HaarClassifierNodeDescriptor32 nextNodeDescriptor;
nextNodeDescriptor = (curNodeVal < scaleArea * pixelStdDev * nodeThreshold) ? nodeLeft : nodeRight;
if (nextNodeDescriptor.isLeaf())
HaarClassifierNodeDescriptor32 nextNodeDescriptor;
NcvBool nextNodeIsLeaf;
if (curNodeVal < scaleArea * pixelStdDev * nodeThreshold)
{
nextNodeDescriptor = nodeLeft;
nextNodeIsLeaf = featuresDesc.isLeftNodeLeaf();
}
else
{
nextNodeDescriptor = nodeRight;
nextNodeIsLeaf = featuresDesc.isRightNodeLeaf();
}
if (nextNodeIsLeaf)
{
Ncv32f tmpLeafValue = nextNodeDescriptor.getLeafValue();
curStageSum += tmpLeafValue;
@@ -2135,8 +2159,9 @@ NCVStatus ncvApplyHaarClassifierCascade_host(NCVMatrix<Ncv32u> &h_integralImage,
while (bMoreNodesToTraverse)
{
HaarClassifierNode128 curNode = h_HaarNodes.ptr()[curNodeOffset];
Ncv32u curNodeFeaturesNum = curNode.getFeatureDesc().getNumFeatures();
Ncv32u curNodeFeaturesOffs = curNode.getFeatureDesc().getFeaturesOffset();
HaarFeatureDescriptor32 curFeatDesc = curNode.getFeatureDesc();
Ncv32u curNodeFeaturesNum = curFeatDesc.getNumFeatures();
Ncv32u curNodeFeaturesOffs = curFeatDesc.getFeaturesOffset();
Ncv32f curNodeVal = 0.f;
for (Ncv32u iRect=0; iRect<curNodeFeaturesNum; iRect++)
@@ -2161,19 +2186,22 @@ NCVStatus ncvApplyHaarClassifierCascade_host(NCVMatrix<Ncv32u> &h_integralImage,
HaarClassifierNodeDescriptor32 nodeLeft = curNode.getLeftNodeDesc();
HaarClassifierNodeDescriptor32 nodeRight = curNode.getRightNodeDesc();
Ncv32f nodeThreshold = curNode.getThreshold();
HaarClassifierNodeDescriptor32 nextNodeDescriptor;
NcvBool nextNodeIsLeaf;
if (curNodeVal < scaleAreaPixels * h_weights.ptr()[i * h_weights.stride() + j] * nodeThreshold)
{
nextNodeDescriptor = nodeLeft;
nextNodeIsLeaf = curFeatDesc.isLeftNodeLeaf();
}
else
{
nextNodeDescriptor = nodeRight;
nextNodeIsLeaf = curFeatDesc.isRightNodeLeaf();
}
NcvBool tmpIsLeaf = nextNodeDescriptor.isLeaf();
if (tmpIsLeaf)
if (nextNodeIsLeaf)
{
Ncv32f tmpLeafValue = nextNodeDescriptor.getLeafValueHost();
curStageSum += tmpLeafValue;

View File

@@ -112,7 +112,9 @@ struct HaarFeatureDescriptor32
private:
#define HaarFeatureDescriptor32_Interpret_MaskFlagTilted 0x80000000
#define HaarFeatureDescriptor32_CreateCheck_MaxNumFeatures 0x7F
#define HaarFeatureDescriptor32_Interpret_MaskFlagLeftNodeLeaf 0x40000000
#define HaarFeatureDescriptor32_Interpret_MaskFlagRightNodeLeaf 0x20000000
#define HaarFeatureDescriptor32_CreateCheck_MaxNumFeatures 0x1F
#define HaarFeatureDescriptor32_NumFeatures_Shift 24
#define HaarFeatureDescriptor32_CreateCheck_MaxFeatureOffset 0x00FFFFFF
@@ -120,7 +122,8 @@ private:
public:
__host__ NCVStatus create(NcvBool bTilted, Ncv32u numFeatures, Ncv32u offsetFeatures)
__host__ NCVStatus create(NcvBool bTilted, NcvBool bLeftLeaf, NcvBool bRightLeaf,
Ncv32u numFeatures, Ncv32u offsetFeatures)
{
if (numFeatures > HaarFeatureDescriptor32_CreateCheck_MaxNumFeatures)
{
@@ -132,6 +135,8 @@ public:
}
this->desc = 0;
this->desc |= (bTilted ? HaarFeatureDescriptor32_Interpret_MaskFlagTilted : 0);
this->desc |= (bLeftLeaf ? HaarFeatureDescriptor32_Interpret_MaskFlagLeftNodeLeaf : 0);
this->desc |= (bRightLeaf ? HaarFeatureDescriptor32_Interpret_MaskFlagRightNodeLeaf : 0);
this->desc |= (numFeatures << HaarFeatureDescriptor32_NumFeatures_Shift);
this->desc |= offsetFeatures;
return NCV_SUCCESS;
@@ -142,9 +147,19 @@ public:
return (this->desc & HaarFeatureDescriptor32_Interpret_MaskFlagTilted) != 0;
}
__device__ __host__ NcvBool isLeftNodeLeaf(void)
{
return (this->desc & HaarFeatureDescriptor32_Interpret_MaskFlagLeftNodeLeaf) != 0;
}
__device__ __host__ NcvBool isRightNodeLeaf(void)
{
return (this->desc & HaarFeatureDescriptor32_Interpret_MaskFlagRightNodeLeaf) != 0;
}
__device__ __host__ Ncv32u getNumFeatures(void)
{
return (this->desc & ~HaarFeatureDescriptor32_Interpret_MaskFlagTilted) >> HaarFeatureDescriptor32_NumFeatures_Shift;
return (this->desc >> HaarFeatureDescriptor32_NumFeatures_Shift) & HaarFeatureDescriptor32_CreateCheck_MaxNumFeatures;
}
__device__ __host__ Ncv32u getFeaturesOffset(void)
@@ -158,34 +173,18 @@ struct HaarClassifierNodeDescriptor32
{
uint1 _ui1;
#define HaarClassifierNodeDescriptor32_Interpret_MaskSwitch (1 << 30)
__host__ NCVStatus create(Ncv32f leafValue)
{
if ((*(Ncv32u *)&leafValue) & HaarClassifierNodeDescriptor32_Interpret_MaskSwitch)
{
return NCV_HAAR_XML_LOADING_EXCEPTION;
}
*(Ncv32f *)&this->_ui1 = leafValue;
return NCV_SUCCESS;
}
__host__ NCVStatus create(Ncv32u offsetHaarClassifierNode)
{
if (offsetHaarClassifierNode >= HaarClassifierNodeDescriptor32_Interpret_MaskSwitch)
{
return NCV_HAAR_XML_LOADING_EXCEPTION;
}
this->_ui1.x = offsetHaarClassifierNode;
this->_ui1.x |= HaarClassifierNodeDescriptor32_Interpret_MaskSwitch;
return NCV_SUCCESS;
}
__device__ __host__ NcvBool isLeaf(void)
{
return !(this->_ui1.x & HaarClassifierNodeDescriptor32_Interpret_MaskSwitch);
}
__host__ Ncv32f getLeafValueHost(void)
{
return *(Ncv32f *)&this->_ui1.x;
@@ -200,7 +199,7 @@ struct HaarClassifierNodeDescriptor32
__device__ __host__ Ncv32u getNextNodeOffset(void)
{
return (this->_ui1.x & ~HaarClassifierNodeDescriptor32_Interpret_MaskSwitch);
return this->_ui1.x;
}
};