diff --git a/vpx_dsp/psnrhvs.c b/vpx_dsp/psnrhvs.c index 7015c0987..095ba5d13 100644 --- a/vpx_dsp/psnrhvs.c +++ b/vpx_dsp/psnrhvs.c @@ -118,7 +118,8 @@ static double convert_score_db(double _score, double _weight, int bit_depth) { static double calc_psnrhvs(const unsigned char *src, int _systride, const unsigned char *dst, int _dystride, double _par, int _w, int _h, int _step, - const double _csf[8][8], uint32_t bit_depth) { + const double _csf[8][8], uint32_t bit_depth, + uint32_t _shift) { double ret; const uint8_t *_src8 = src; const uint8_t *_dst8 = dst; @@ -172,12 +173,12 @@ static double calc_psnrhvs(const unsigned char *src, int _systride, for (i = 0; i < 8; i++) { for (j = 0; j < 8; j++) { int sub = ((i & 12) >> 2) + ((j & 12) >> 1); - if (bit_depth == 8) { + if (bit_depth == 8 && _shift == 0) { dct_s[i * 8 + j] = _src8[(y + i) * _systride + (j + x)]; dct_d[i * 8 + j] = _dst8[(y + i) * _dystride + (j + x)]; } else if (bit_depth == 10 || bit_depth == 12) { - dct_s[i * 8 + j] = _src16[(y + i) * _systride + (j + x)]; - dct_d[i * 8 + j] = _dst16[(y + i) * _dystride + (j + x)]; + dct_s[i * 8 + j] = _src16[(y + i) * _systride + (j + x)] >> _shift; + dct_d[i * 8 + j] = _dst16[(y + i) * _dystride + (j + x)] >> _shift; } s_gmean += dct_s[i * 8 + j]; d_gmean += dct_d[i * 8 + j]; @@ -255,20 +256,27 @@ double vpx_psnrhvs(const YV12_BUFFER_CONFIG *src, double psnrhvs; const double par = 1.0; const int step = 7; + uint32_t bd_shift = 0; vpx_clear_system_state(); assert(bd == 8 || bd == 10 || bd == 12); + assert(bd >= in_bd); + + bd_shift = bd - in_bd; *y_psnrhvs = calc_psnrhvs(src->y_buffer, src->y_stride, dest->y_buffer, dest->y_stride, par, src->y_crop_width, - src->y_crop_height, step, csf_y, bd); + src->y_crop_height, step, csf_y, bd, + bd_shift); *u_psnrhvs = calc_psnrhvs(src->u_buffer, src->uv_stride, dest->u_buffer, dest->uv_stride, par, src->uv_crop_width, - src->uv_crop_height, step, csf_cb420, bd); + src->uv_crop_height, step, csf_cb420, bd, + bd_shift); *v_psnrhvs = calc_psnrhvs(src->v_buffer, src->uv_stride, dest->v_buffer, dest->uv_stride, par, src->uv_crop_width, - src->uv_crop_height, step, csf_cr420, bd); + src->uv_crop_height, step, csf_cr420, bd, + bd_shift); psnrhvs = (*y_psnrhvs) * .8 + .1 * ((*u_psnrhvs) + (*v_psnrhvs)); - return convert_score_db(psnrhvs, 1.0, bd); + return convert_score_db(psnrhvs, 1.0, in_bd); }