Unified approach for backward probability update.

Replacing update_mode_probs() and adapt_probs() with tree_merge_probs().

Change-Id: I50b2c968d67c9265f5216c700cbeba25fb014654
This commit is contained in:
Dmitry Kovalev
2013-11-04 16:12:29 -08:00
parent dde8069e57
commit c622e1d18f
3 changed files with 88 additions and 80 deletions

View File

@@ -350,23 +350,15 @@ void vp9_entropy_mode_init() {
#define COUNT_SAT 20 #define COUNT_SAT 20
#define MAX_UPDATE_FACTOR 128 #define MAX_UPDATE_FACTOR 128
static int update_ct(vp9_prob pre_prob, const unsigned int ct[2]) { static int adapt_prob(vp9_prob pre_prob, const unsigned int ct[2]) {
return merge_probs(pre_prob, ct, COUNT_SAT, MAX_UPDATE_FACTOR); return merge_probs(pre_prob, ct, COUNT_SAT, MAX_UPDATE_FACTOR);
} }
static void update_mode_probs(int n_modes, static void adapt_probs(const vp9_tree_index *tree,
const vp9_tree_index *tree, const vp9_prob *pre_probs, const unsigned int *counts,
const unsigned int *cnt, unsigned int offset, vp9_prob *probs) {
const vp9_prob *pre_probs, vp9_prob *dst_probs, tree_merge_probs(tree, pre_probs, counts, offset,
unsigned int tok0_offset) { COUNT_SAT, MAX_UPDATE_FACTOR, probs);
#define MAX_PROBS 32
unsigned int branch_ct[MAX_PROBS][2];
int t;
assert(n_modes - 1 < MAX_PROBS);
vp9_tree_probs_from_distribution(tree, branch_ct, cnt, tok0_offset);
for (t = 0; t < n_modes - 1; ++t)
dst_probs[t] = update_ct(pre_probs[t], branch_ct[t]);
} }
void vp9_adapt_mode_probs(VP9_COMMON *cm) { void vp9_adapt_mode_probs(VP9_COMMON *cm) {
@@ -376,44 +368,40 @@ void vp9_adapt_mode_probs(VP9_COMMON *cm) {
const FRAME_COUNTS *counts = &cm->counts; const FRAME_COUNTS *counts = &cm->counts;
for (i = 0; i < INTRA_INTER_CONTEXTS; i++) for (i = 0; i < INTRA_INTER_CONTEXTS; i++)
fc->intra_inter_prob[i] = update_ct(pre_fc->intra_inter_prob[i], fc->intra_inter_prob[i] = adapt_prob(pre_fc->intra_inter_prob[i],
counts->intra_inter[i]); counts->intra_inter[i]);
for (i = 0; i < COMP_INTER_CONTEXTS; i++) for (i = 0; i < COMP_INTER_CONTEXTS; i++)
fc->comp_inter_prob[i] = update_ct(pre_fc->comp_inter_prob[i], fc->comp_inter_prob[i] = adapt_prob(pre_fc->comp_inter_prob[i],
counts->comp_inter[i]); counts->comp_inter[i]);
for (i = 0; i < REF_CONTEXTS; i++) for (i = 0; i < REF_CONTEXTS; i++)
fc->comp_ref_prob[i] = update_ct(pre_fc->comp_ref_prob[i], fc->comp_ref_prob[i] = adapt_prob(pre_fc->comp_ref_prob[i],
counts->comp_ref[i]); counts->comp_ref[i]);
for (i = 0; i < REF_CONTEXTS; i++) for (i = 0; i < REF_CONTEXTS; i++)
for (j = 0; j < 2; j++) for (j = 0; j < 2; j++)
fc->single_ref_prob[i][j] = update_ct(pre_fc->single_ref_prob[i][j], fc->single_ref_prob[i][j] = adapt_prob(pre_fc->single_ref_prob[i][j],
counts->single_ref[i][j]); counts->single_ref[i][j]);
for (i = 0; i < INTER_MODE_CONTEXTS; i++) for (i = 0; i < INTER_MODE_CONTEXTS; i++)
update_mode_probs(INTER_MODES, vp9_inter_mode_tree, adapt_probs(vp9_inter_mode_tree, pre_fc->inter_mode_probs[i],
counts->inter_mode[i], pre_fc->inter_mode_probs[i], counts->inter_mode[i], NEARESTMV, fc->inter_mode_probs[i]);
fc->inter_mode_probs[i], NEARESTMV);
for (i = 0; i < BLOCK_SIZE_GROUPS; i++) for (i = 0; i < BLOCK_SIZE_GROUPS; i++)
update_mode_probs(INTRA_MODES, vp9_intra_mode_tree, adapt_probs(vp9_intra_mode_tree, pre_fc->y_mode_prob[i],
counts->y_mode[i], pre_fc->y_mode_prob[i], counts->y_mode[i], 0, fc->y_mode_prob[i]);
fc->y_mode_prob[i], 0);
for (i = 0; i < INTRA_MODES; ++i) for (i = 0; i < INTRA_MODES; ++i)
update_mode_probs(INTRA_MODES, vp9_intra_mode_tree, adapt_probs(vp9_intra_mode_tree, pre_fc->uv_mode_prob[i],
counts->uv_mode[i], pre_fc->uv_mode_prob[i], counts->uv_mode[i], 0, fc->uv_mode_prob[i]);
fc->uv_mode_prob[i], 0);
for (i = 0; i < PARTITION_CONTEXTS; i++) for (i = 0; i < PARTITION_CONTEXTS; i++)
update_mode_probs(PARTITION_TYPES, vp9_partition_tree, counts->partition[i], adapt_probs(vp9_partition_tree, pre_fc->partition_prob[i],
pre_fc->partition_prob[i], fc->partition_prob[i], 0); counts->partition[i], 0, fc->partition_prob[i]);
if (cm->mcomp_filter_type == SWITCHABLE) { if (cm->mcomp_filter_type == SWITCHABLE) {
for (i = 0; i < SWITCHABLE_FILTER_CONTEXTS; i++) for (i = 0; i < SWITCHABLE_FILTER_CONTEXTS; i++)
update_mode_probs(SWITCHABLE_FILTERS, vp9_switchable_interp_tree, adapt_probs(vp9_switchable_interp_tree, pre_fc->switchable_interp_prob[i],
counts->switchable_interp[i], counts->switchable_interp[i], 0,
pre_fc->switchable_interp_prob[i], fc->switchable_interp_prob[i]);
fc->switchable_interp_prob[i], 0);
} }
if (cm->tx_mode == TX_MODE_SELECT) { if (cm->tx_mode == TX_MODE_SELECT) {
@@ -425,23 +413,24 @@ void vp9_adapt_mode_probs(VP9_COMMON *cm) {
for (i = 0; i < TX_SIZE_CONTEXTS; ++i) { for (i = 0; i < TX_SIZE_CONTEXTS; ++i) {
tx_counts_to_branch_counts_8x8(counts->tx.p8x8[i], branch_ct_8x8p); tx_counts_to_branch_counts_8x8(counts->tx.p8x8[i], branch_ct_8x8p);
for (j = 0; j < TX_SIZES - 3; ++j) for (j = 0; j < TX_SIZES - 3; ++j)
fc->tx_probs.p8x8[i][j] = update_ct(pre_fc->tx_probs.p8x8[i][j], fc->tx_probs.p8x8[i][j] = adapt_prob(pre_fc->tx_probs.p8x8[i][j],
branch_ct_8x8p[j]); branch_ct_8x8p[j]);
tx_counts_to_branch_counts_16x16(counts->tx.p16x16[i], branch_ct_16x16p); tx_counts_to_branch_counts_16x16(counts->tx.p16x16[i], branch_ct_16x16p);
for (j = 0; j < TX_SIZES - 2; ++j) for (j = 0; j < TX_SIZES - 2; ++j)
fc->tx_probs.p16x16[i][j] = update_ct(pre_fc->tx_probs.p16x16[i][j], fc->tx_probs.p16x16[i][j] = adapt_prob(pre_fc->tx_probs.p16x16[i][j],
branch_ct_16x16p[j]); branch_ct_16x16p[j]);
tx_counts_to_branch_counts_32x32(counts->tx.p32x32[i], branch_ct_32x32p); tx_counts_to_branch_counts_32x32(counts->tx.p32x32[i], branch_ct_32x32p);
for (j = 0; j < TX_SIZES - 1; ++j) for (j = 0; j < TX_SIZES - 1; ++j)
fc->tx_probs.p32x32[i][j] = update_ct(pre_fc->tx_probs.p32x32[i][j], fc->tx_probs.p32x32[i][j] = adapt_prob(pre_fc->tx_probs.p32x32[i][j],
branch_ct_32x32p[j]); branch_ct_32x32p[j]);
} }
} }
for (i = 0; i < MBSKIP_CONTEXTS; ++i) for (i = 0; i < MBSKIP_CONTEXTS; ++i)
fc->mbskip_probs[i] = update_ct(pre_fc->mbskip_probs[i], counts->mbskip[i]); fc->mbskip_probs[i] = adapt_prob(pre_fc->mbskip_probs[i],
counts->mbskip[i]);
} }
static void set_default_lf_deltas(struct loopfilter *lf) { static void set_default_lf_deltas(struct loopfilter *lf) {

View File

@@ -194,57 +194,44 @@ static vp9_prob adapt_prob(vp9_prob prep, const unsigned int ct[2]) {
return merge_probs(prep, ct, MV_COUNT_SAT, MV_MAX_UPDATE_FACTOR); return merge_probs(prep, ct, MV_COUNT_SAT, MV_MAX_UPDATE_FACTOR);
} }
static unsigned int adapt_probs(unsigned int i, static void adapt_probs(const vp9_tree_index *tree, const vp9_prob *pre_probs,
vp9_tree tree, const unsigned int *counts, vp9_prob *probs) {
vp9_prob this_probs[], tree_merge_probs(tree, pre_probs, counts, 0,
const vp9_prob last_probs[], MV_COUNT_SAT, MV_MAX_UPDATE_FACTOR, probs);
const unsigned int num_events[]) {
const unsigned int left = tree[i] <= 0
? num_events[-tree[i]]
: adapt_probs(tree[i], tree, this_probs, last_probs, num_events);
const unsigned int right = tree[i + 1] <= 0
? num_events[-tree[i + 1]]
: adapt_probs(tree[i + 1], tree, this_probs, last_probs, num_events);
const unsigned int ct[2] = { left, right };
this_probs[i >> 1] = adapt_prob(last_probs[i >> 1], ct);
return left + right;
} }
void vp9_adapt_mv_probs(VP9_COMMON *cm, int allow_hp) { void vp9_adapt_mv_probs(VP9_COMMON *cm, int allow_hp) {
int i, j; int i, j;
const FRAME_CONTEXT *pre_fc = &cm->frame_contexts[cm->frame_context_idx]; nmv_context *fc = &cm->fc.nmvc;
const nmv_context *pre_fc = &cm->frame_contexts[cm->frame_context_idx].nmvc;
const nmv_context_counts *counts = &cm->counts.mv;
nmv_context *ctx = &cm->fc.nmvc; adapt_probs(vp9_mv_joint_tree, pre_fc->joints, counts->joints,
const nmv_context *pre_ctx = &pre_fc->nmvc; fc->joints);
const nmv_context_counts *cts = &cm->counts.mv;
adapt_probs(0, vp9_mv_joint_tree, ctx->joints, pre_ctx->joints, cts->joints);
for (i = 0; i < 2; ++i) { for (i = 0; i < 2; ++i) {
ctx->comps[i].sign = adapt_prob(pre_ctx->comps[i].sign, cts->comps[i].sign); nmv_component *comp = &fc->comps[i];
adapt_probs(0, vp9_mv_class_tree, ctx->comps[i].classes, const nmv_component *pre_comp = &pre_fc->comps[i];
pre_ctx->comps[i].classes, cts->comps[i].classes); const nmv_component_counts *c = &counts->comps[i];
adapt_probs(0, vp9_mv_class0_tree, ctx->comps[i].class0,
pre_ctx->comps[i].class0, cts->comps[i].class0); comp->sign = adapt_prob(pre_comp->sign, c->sign);
adapt_probs(vp9_mv_class_tree, pre_comp->classes, c->classes,
comp->classes);
adapt_probs(vp9_mv_class0_tree, pre_comp->class0, c->class0, comp->class0);
for (j = 0; j < MV_OFFSET_BITS; ++j) for (j = 0; j < MV_OFFSET_BITS; ++j)
ctx->comps[i].bits[j] = adapt_prob(pre_ctx->comps[i].bits[j], comp->bits[j] = adapt_prob(pre_comp->bits[j], c->bits[j]);
cts->comps[i].bits[j]);
for (j = 0; j < CLASS0_SIZE; ++j) for (j = 0; j < CLASS0_SIZE; ++j)
adapt_probs(0, vp9_mv_fp_tree, ctx->comps[i].class0_fp[j], adapt_probs(vp9_mv_fp_tree, pre_comp->class0_fp[j], c->class0_fp[j],
pre_ctx->comps[i].class0_fp[j], cts->comps[i].class0_fp[j]); comp->class0_fp[j]);
adapt_probs(0, vp9_mv_fp_tree, ctx->comps[i].fp, pre_ctx->comps[i].fp, adapt_probs(vp9_mv_fp_tree, pre_comp->fp, c->fp, comp->fp);
cts->comps[i].fp);
if (allow_hp) { if (allow_hp) {
ctx->comps[i].class0_hp = adapt_prob(pre_ctx->comps[i].class0_hp, comp->class0_hp = adapt_prob(pre_comp->class0_hp, c->class0_hp);
cts->comps[i].class0_hp); comp->hp = adapt_prob(pre_comp->hp, c->hp);
ctx->comps[i].hp = adapt_prob(pre_ctx->comps[i].hp, cts->comps[i].hp);
} }
} }
} }

View File

@@ -91,5 +91,37 @@ static INLINE vp9_prob merge_probs(vp9_prob pre_prob,
return weighted_prob(pre_prob, prob, factor); return weighted_prob(pre_prob, prob, factor);
} }
static unsigned int tree_merge_probs_impl(unsigned int i,
const vp9_tree_index *tree,
const vp9_prob *pre_probs,
const unsigned int *counts,
unsigned int count_sat,
unsigned int max_update_factor,
vp9_prob *probs) {
const int l = tree[i];
const unsigned int left_count = (l <= 0)
? counts[-l]
: tree_merge_probs_impl(l, tree, pre_probs, counts,
count_sat, max_update_factor, probs);
const int r = tree[i + 1];
const unsigned int right_count = (r <= 0)
? counts[-r]
: tree_merge_probs_impl(r, tree, pre_probs, counts,
count_sat, max_update_factor, probs);
const unsigned int ct[2] = { left_count, right_count };
probs[i >> 1] = merge_probs(pre_probs[i >> 1], ct,
count_sat, max_update_factor);
return left_count + right_count;
}
static void tree_merge_probs(const vp9_tree_index *tree,
const vp9_prob *pre_probs,
const unsigned int *counts, int offset,
unsigned int count_sat,
unsigned int max_update_factor, vp9_prob *probs) {
tree_merge_probs_impl(0, tree, pre_probs, &counts[-offset],
count_sat, max_update_factor, probs);
}
#endif // VP9_COMMON_VP9_TREECODER_H_ #endif // VP9_COMMON_VP9_TREECODER_H_