vpx/vp8/common/pred_common.c
Ronald S. Bultje 0b8a95a0b2 Rewrite reference frame costing in the RD loop.
I now see I didn't write a very long description, so let's do it
here then. We took a pretty big quality hit (0.1-0.2%) from my
recent fix of the inversion of arguments to vp8_cost_bit() in the
RD reference frame costing. I looked into it and basically the
costing prevented us from switching reference frames. This is of
course silly, since each frame codes its own prob_intra_coded, so
using last frame cost indications as a limiting factor can never
be right.

Here, I've rewritten that code to estimate costs based partially
on statistics gathered so far while encoding the current frame. Overall,
this gives us a ~0.2%-0.3% improvement over what we had previously
before my argument-inversion-fix, and thus about ~0.4% over current
git (on derf-set), and a little more (0.5-1.0%) on HD/STD-HD/YT.

Change-Id: I79ebd4ccec4d6edbf0e152d9590d103ba2747775
2012-05-15 15:32:44 -07:00

365 lines
11 KiB
C

/*
* Copyright (c) 2012 The WebM project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "vp8/common/pred_common.h"
// TBD prediction functions for various bitstream signals
// Returns a context number for the given MB prediction signal
unsigned char get_pred_context( VP8_COMMON *const cm,
                                MACROBLOCKD *const xd,
                                PRED_ID pred_id )
{
    // Returns the context number used to pick the probability for coding
    // the given prediction signal of the current macroblock.
    //
    // The mode info array has a one-entry border above and to the left of
    // the entries corresponding to real macroblocks, and the prediction
    // flags in those border entries are initialised to 0. That is why the
    // left and above neighbours can be read unconditionally below.
    const MODE_INFO *const cur = xd->mode_info_context;
    const MODE_INFO *const left_mi = cur - 1;
    const MODE_INFO *const above_mi = cur - cm->mode_info_stride;
    int ctx = 0;

    switch (pred_id)
    {
        case PRED_SEG_ID:
            // 0, 1 or 2 neighbours whose segment id was correctly predicted.
            ctx = left_mi->mbmi.seg_id_predicted +
                  above_mi->mbmi.seg_id_predicted;
            break;

        case PRED_REF:
            // 0, 1 or 2 neighbours whose reference frame was correctly
            // predicted.
            ctx = left_mi->mbmi.ref_predicted +
                  above_mi->mbmi.ref_predicted;
            break;

        case PRED_COMP:
            // Context depends only on whether the current macroblock's
            // reference frame is LAST_FRAME.
            ctx = (cur->mbmi.ref_frame == LAST_FRAME) ? 0 : 1;
            break;

#if CONFIG_NEWENTROPY
        case PRED_MBSKIP:
            // 0, 1 or 2 neighbours that have their skip flag set.
            ctx = left_mi->mbmi.mb_skip_coeff +
                  above_mi->mbmi.mb_skip_coeff;
            break;
#endif

        default:
            // TODO *** add error trap code.
            ctx = 0;
            break;
    }

    return (unsigned char)ctx;
}
// This function returns a context probability for coding a given
// prediction signal
vp8_prob get_pred_prob( VP8_COMMON *const cm,
                        MACROBLOCKD *const xd,
                        PRED_ID pred_id )
{
    // Returns the probability used to code the given prediction signal.
    // The table row is selected by the context from get_pred_context().
    // Unknown signals fall back to the neutral probability 128.
    const int ctx = get_pred_context( cm, xd, pred_id );

    switch (pred_id)
    {
        case PRED_SEG_ID:
            return cm->segment_pred_probs[ctx];

        case PRED_REF:
            return cm->ref_pred_probs[ctx];

        case PRED_COMP:
            // In keeping with convention elsewhere, the probability returned
            // is that of a "0" outcome, which here means compound
            // prediction is off.
            return cm->prob_comppred[ctx];

#if CONFIG_NEWENTROPY
        case PRED_MBSKIP:
            return cm->mbskip_pred_probs[ctx];
#endif

        default:
            // TODO *** add error trap code.
            return 128;
    }
}
// This function returns the status of the given prediction signal,
// i.e. whether the predicted value for the given signal was correct.
unsigned char get_pred_flag( MACROBLOCKD *const xd,
                             PRED_ID pred_id )
{
    // Reads the flag stored in the current macroblock's mode info for the
    // given prediction signal. Signals without a stored flag report 0.
    switch (pred_id)
    {
        case PRED_SEG_ID:
            return xd->mode_info_context->mbmi.seg_id_predicted;

        case PRED_REF:
            return xd->mode_info_context->mbmi.ref_predicted;

#if CONFIG_NEWENTROPY
        case PRED_MBSKIP:
            return xd->mode_info_context->mbmi.mb_skip_coeff;
#endif

        default:
            // TODO *** add error trap code.
            return 0;
    }
}
// This function sets the status of the given prediction signal,
// i.e. whether the predicted value for the given signal was correct.
void set_pred_flag( MACROBLOCKD *const xd,
                    PRED_ID pred_id,
                    unsigned char pred_flag)
{
    // Stores pred_flag in the current macroblock's mode info field that
    // belongs to the given prediction signal. Signals without a stored flag
    // are ignored.
    if ( pred_id == PRED_SEG_ID )
    {
        xd->mode_info_context->mbmi.seg_id_predicted = pred_flag;
    }
    else if ( pred_id == PRED_REF )
    {
        xd->mode_info_context->mbmi.ref_predicted = pred_flag;
    }
#if CONFIG_NEWENTROPY
    else if ( pred_id == PRED_MBSKIP )
    {
        xd->mode_info_context->mbmi.mb_skip_coeff = pred_flag;
    }
#endif
    // TODO *** add error trap code for unhandled pred_id values.
}
// The following functions contain the core of the prediction code used
// to predict various bitstream signals.
// Macroblock segment ID prediction function.
unsigned char get_pred_mb_segid( VP8_COMMON *const cm, int MbIndex )
{
    // The predicted segment ID of a macroblock is the value recorded at the
    // same index in the previous frame's segment map.
    const unsigned char predicted_segid = cm->last_frame_seg_map[MbIndex];
    return predicted_segid;
}
MV_REFERENCE_FRAME get_pred_ref( VP8_COMMON *const cm,
                                 MACROBLOCKD *const xd )
{
    // Predicts the reference frame of the current macroblock.
    //
    // Scoring rule:
    //  - Each candidate frame starts from its frame-level score in
    //    cm->ref_scores. If segment reference-frame coding is active, frames
    //    the segment does not allow start at 0.
    //  - The left and above neighbours each add 16 to the score of the
    //    frame they use, provided they are in the image and that frame is
    //    allowed.
    //  - They add a further 4 when the above-left neighbour is also in the
    //    image and uses the same frame.
    //  - The highest score wins. Ties go to the lowest frame index, and
    //    LAST_FRAME is returned if no score is positive.
    MODE_INFO *const here = xd->mode_info_context;
    const MODE_INFO *const left_mi = here - 1;
    const MODE_INFO *const above_mi = here - cm->mode_info_stride;
    const MODE_INFO *const above_left_mi = here - 1 - cm->mode_info_stride;
    const int segment_id = here->mbmi.segment_id;
    unsigned char allowed[MAX_REF_FRAMES] = {1,1,1,1};
    unsigned char score[MAX_REF_FRAMES];
    unsigned char top_score = 0;
    MV_REFERENCE_FRAME best_ref = LAST_FRAME;
    MV_REFERENCE_FRAME left_ref;
    MV_REFERENCE_FRAME above_ref;
    MV_REFERENCE_FRAME above_left_ref;
    unsigned char left_in;
    unsigned char above_in;
    unsigned char above_left_in;
    int frame;

    if ( segfeature_active( xd, segment_id, SEG_LVL_REF_FRAME ) )
    {
        // Never predict a reference frame that the segment does not allow:
        // such frames get a score of 0.
        for ( frame = 0; frame < MAX_REF_FRAMES; frame++ )
        {
            allowed[frame] = check_segref( xd, segment_id, frame );
            score[frame] = cm->ref_scores[frame] * allowed[frame];
        }
    }
    else
    {
        vpx_memcpy( score, cm->ref_scores, sizeof(score) );
    }

    // Reference frames used by the neighbours, and whether each neighbour
    // lies inside the image.
    left_ref = left_mi->mbmi.ref_frame;
    above_ref = above_mi->mbmi.ref_frame;
    above_left_ref = above_left_mi->mbmi.ref_frame;
    left_in = left_mi->mbmi.mb_in_image;
    above_in = above_mi->mbmi.mb_in_image;
    above_left_in = above_left_mi->mbmi.mb_in_image;

    // Boost the candidates used by the neighbours.
    if ( allowed[left_ref] && left_in )
    {
        score[left_ref] += 16;
        if ( above_left_in && left_ref == above_left_ref )
            score[left_ref] += 4;
    }
    if ( allowed[above_ref] && above_in )
    {
        score[above_ref] += 16;
        if ( above_left_in && above_ref == above_left_ref )
            score[above_ref] += 4;
    }

    // Pick the first candidate with the strictly highest score.
    for ( frame = 0; frame < MAX_REF_FRAMES; frame++ )
    {
        if ( score[frame] > top_score )
        {
            top_score = score[frame];
            best_ref = frame;
        }
    }

    return best_ref;
}
// Computes a set of modified reference frame probabilities
// to use when the prediction of the reference frame value fails.
void calc_ref_probs( int * count, vp8_prob * probs )
{
    // Converts four counts (count[0..3]) into three conditional
    // probabilities on a 0..255 scale.
    //
    // probs[k] is the rounded share of count[k] among count[k..3], i.e. the
    // probability of outcome k given that none of outcomes 0..k-1 occurred.
    // A non-zero share is never allowed to round down to 0 (it becomes 1).
    // When count[k..3] sums to 0, probs[k] gets the neutral value 128.
    int remaining = count[0] + count[1] + count[2] + count[3];
    int node;

    for ( node = 0; node < 3; node++ )
    {
        if ( remaining == 0 )
        {
            probs[node] = 128;
        }
        else
        {
            const vp8_prob p =
                (vp8_prob)((count[node] * 255 + (remaining >> 1)) / remaining);
            probs[node] = p ? p : 1;
        }
        remaining -= count[node];
    }
}
// Computes a set of modified conditional probabilities for the reference frame.
// Values will be set to 0 for reference frame options that are not possible,
// either because they were predicted and the prediction has failed or because
// they are not allowed for a given segment.
void compute_mod_refprobs( VP8_COMMON *const cm )
{
    // Reference frame stored in each count slot, in slot order.
    static const int slot_frame[4] =
        { INTRA_FRAME, LAST_FRAME, GOLDEN_FRAME, ALTREF_FRAME };
    // Probability node whose outcome becomes implicit once the frame in that
    // slot is ruled out (GOLDEN and ALTREF share the same node).
    static const int implicit_node[4] = { 0, 1, 2, 2 };
    int counts[4];
    int masked[MAX_REF_FRAMES];
    int inter_count;
    int gfarf_count;
    int slot;
    int j;

    // Approximate per-frame counts (summing to 255) from the frame-level
    // binary probabilities.
    counts[0] = cm->prob_intra_coded;
    inter_count = (255 - counts[0]);
    counts[1] = (inter_count * cm->prob_last_coded)/255;
    gfarf_count = inter_count - counts[1];
    counts[2] = (gfarf_count * cm->prob_gf_coded)/255;
    counts[3] = gfarf_count - counts[2];

    // For each possible predicted frame, work out the probabilities to use
    // when prediction of the reference frame fails. The predicted frame's
    // count is removed first, and its now-implicit branch probability is
    // then forced to 0.
    for ( slot = 0; slot < 4; slot++ )
    {
        for ( j = 0; j < 4; j++ )
            masked[j] = (j == slot) ? 0 : counts[j];

        calc_ref_probs( masked, cm->mod_refprobs[slot_frame[slot]] );
        cm->mod_refprobs[slot_frame[slot]][implicit_node[slot]] = 0;
    }

    // Score each reference frame by its overall frequency; these scores feed
    // the reference frame prediction. Range: 1 (never used) to 17.
    for ( slot = 0; slot < 4; slot++ )
        cm->ref_scores[slot_frame[slot]] = 1 + (counts[slot] * 16 / 255);
}