Adds forward prob updates for supertx expt

Also refactors the supertx prob models, and includes other
cleanups.

Change-Id: I74de6c01d872ae09bf6d43a31f53d43283b6b226
This commit is contained in:
Deb Mukherjee 2014-12-12 14:07:02 -08:00
parent 5de9280ae9
commit 7ec9792f96
12 changed files with 178 additions and 64 deletions

View File

@ -268,11 +268,13 @@ static INLINE BLOCK_SIZE get_subsize(BLOCK_SIZE bsize,
extern const TX_TYPE intra_mode_to_tx_type_lookup[INTRA_MODES];
#if CONFIG_SUPERTX
#define PARTITION_SUPERTX_CONTEXTS 2
#if CONFIG_TX64X64
#define MAX_SUPERTX_BLOCK_SIZE BLOCK_64X64
#else
#define MAX_SUPERTX_BLOCK_SIZE BLOCK_32X32
#endif
#endif // CONFIG_TX64X64
static INLINE TX_SIZE bsize_to_tx_size(BLOCK_SIZE bsize) {
const TX_SIZE bsize_to_tx_size_lookup[BLOCK_SIZES] = {
@ -308,7 +310,7 @@ static TX_TYPE ext_tx_to_txtype[EXT_TX_TYPES] = {
FLIPADST_DCT,
DCT_FLIPADST,
};
#endif
#endif // CONFIG_EXT_TX
static INLINE TX_TYPE get_tx_type(PLANE_TYPE plane_type,
const MACROBLOCKD *xd) {

View File

@ -184,4 +184,8 @@ const TX_SIZE uvsupertx_size_lookup[TX_SIZES][2][2] = {
{{TX_64X64, TX_32X32}, {TX_32X32, TX_32X32}},
#endif // CONFIG_TX64X64
};
const int partition_supertx_context_lookup[PARTITION_TYPES] = {
-1, 0, 0, 1
};
#endif

View File

@ -34,6 +34,7 @@ extern const TX_SIZE tx_mode_to_biggest_tx_size[TX_MODES];
extern const BLOCK_SIZE ss_size_lookup[BLOCK_SIZES][2][2];
#if CONFIG_SUPERTX
extern const TX_SIZE uvsupertx_size_lookup[TX_SIZES][2][2];
extern const int partition_supertx_context_lookup[PARTITION_TYPES];
#endif
#ifdef __cplusplus

View File

@ -22,6 +22,7 @@ extern "C" {
#endif
#define DIFF_UPDATE_PROB 252
#define GROUP_DIFF_UPDATE_PROB 252
// Coefficient token alphabet
#define ZERO_TOKEN 0 // 0 Extra Bits 0+0

View File

@ -314,20 +314,17 @@ static const vp9_prob default_ext_tx_prob[3][EXT_TX_TYPES - 1] = {
#endif // CONFIG_EXT_TX
#if CONFIG_SUPERTX
static const vp9_prob default_supertx_prob[TX_SIZES] = {
255, 160, 160, 160,
static const vp9_prob default_supertx_prob[PARTITION_SUPERTX_CONTEXTS]
[TX_SIZES] = {
#if CONFIG_TX64X64
160
{ 1, 160, 160, 170, 170 },
{ 1, 200, 200, 210, 210 },
#else
{ 1, 160, 160, 170 },
{ 1, 200, 200, 210 },
#endif
};
static const vp9_prob default_supertxsplit_prob[TX_SIZES] = {
255, 200, 200, 200,
#if CONFIG_TX64X64
200
#endif
};
#endif
#endif // CONFIG_SUPERTX
#if CONFIG_TX64X64
void tx_counts_to_branch_counts_64x64(const unsigned int *tx_count_64x64p,
@ -408,7 +405,6 @@ void vp9_init_mode_probs(FRAME_CONTEXT *fc) {
#endif
#if CONFIG_SUPERTX
vp9_copy(fc->supertx_prob, default_supertx_prob);
vp9_copy(fc->supertxsplit_prob, default_supertxsplit_prob);
#endif
#if CONFIG_TX_SKIP
vp9_copy(fc->y_tx_skip_prob, default_y_tx_skip_prob);
@ -529,13 +525,11 @@ void vp9_adapt_mode_probs(VP9_COMMON *cm) {
#endif // CONFIG_EXT_TX
#if CONFIG_SUPERTX
for (i = 1; i < TX_SIZES; ++i) {
fc->supertx_prob[i] = adapt_prob(pre_fc->supertx_prob[i],
counts->supertx[i]);
}
for (i = 1; i < TX_SIZES; ++i) {
fc->supertxsplit_prob[i] = adapt_prob(pre_fc->supertxsplit_prob[i],
counts->supertxsplit[i]);
for (i = 0; i < PARTITION_SUPERTX_CONTEXTS; ++i) {
for (j = 1; j < TX_SIZES; ++j) {
fc->supertx_prob[i][j] = adapt_prob(pre_fc->supertx_prob[i][j],
counts->supertx[i][j]);
}
}
#endif // CONFIG_SUPERTX
#if CONFIG_TX_SKIP

View File

@ -63,8 +63,7 @@ typedef struct frame_contexts {
vp9_prob ext_tx_prob[3][EXT_TX_TYPES - 1];
#endif
#if CONFIG_SUPERTX
vp9_prob supertx_prob[TX_SIZES];
vp9_prob supertxsplit_prob[TX_SIZES];
vp9_prob supertx_prob[PARTITION_SUPERTX_CONTEXTS][TX_SIZES];
#endif
#if CONFIG_TX_SKIP
vp9_prob y_tx_skip_prob[2];
@ -96,8 +95,7 @@ typedef struct {
unsigned int ext_tx[3][EXT_TX_TYPES];
#endif
#if CONFIG_SUPERTX
unsigned int supertx[TX_SIZES][2];
unsigned int supertxsplit[TX_SIZES][2];
unsigned int supertx[PARTITION_SUPERTX_CONTEXTS][TX_SIZES][2];
unsigned int supertx_size[BLOCK_SIZES];
#endif
#if CONFIG_TX_SKIP

View File

@ -999,13 +999,11 @@ static void decode_partition(VP9_COMMON *const cm, MACROBLOCKD *const xd,
partition != PARTITION_NONE &&
bsize <= MAX_SUPERTX_BLOCK_SIZE &&
!supertx_enabled) {
if (partition == PARTITION_SPLIT) {
supertx_enabled = vp9_read(r, cm->fc.supertxsplit_prob[supertx_size]);
cm->counts.supertxsplit[supertx_size][supertx_enabled]++;
} else {
supertx_enabled = vp9_read(r, cm->fc.supertx_prob[supertx_size]);
cm->counts.supertx[supertx_size][supertx_enabled]++;
}
const int supertx_context =
partition_supertx_context_lookup[partition];
supertx_enabled = vp9_read(
r, cm->fc.supertx_prob[supertx_context][supertx_size]);
cm->counts.supertx[supertx_context][supertx_size][supertx_enabled]++;
}
if (supertx_enabled && read_token) {
int offset = mi_row * cm->mi_stride + mi_col;
@ -2020,6 +2018,30 @@ static size_t read_uncompressed_header(VP9Decoder *pbi,
return sz;
}
#if CONFIG_EXT_TX
static void read_ext_tx_probs(FRAME_CONTEXT *fc, vp9_reader *r) {
int i, j;
if (vp9_read(r, GROUP_DIFF_UPDATE_PROB)) {
for (j = TX_4X4; j <= TX_16X16; ++j)
for (i = 0; i < EXT_TX_TYPES - 1; ++i)
vp9_diff_update_prob(r, &fc->ext_tx_prob[j][i]);
}
}
#endif // CONFIG_EXT_TX
#if CONFIG_SUPERTX
static void read_supertx_probs(FRAME_CONTEXT *fc, vp9_reader *r) {
int i, j;
if (vp9_read(r, GROUP_DIFF_UPDATE_PROB)) {
for (i = 0; i < PARTITION_SUPERTX_CONTEXTS; ++i) {
for (j = 1; j < TX_SIZES; ++j) {
vp9_diff_update_prob(r, &fc->supertx_prob[i][j]);
}
}
}
}
#endif // CONFIG_SUPERTX
static int read_compressed_header(VP9Decoder *pbi, const uint8_t *data,
size_t partition_size) {
VP9_COMMON *const cm = &pbi->common;
@ -2068,9 +2090,10 @@ static int read_compressed_header(VP9Decoder *pbi, const uint8_t *data,
read_mv_probs(nmvc, cm->allow_high_precision_mv, &r);
#if CONFIG_EXT_TX
for (j = TX_4X4; j <= TX_16X16; ++j)
for (i = 0; i < EXT_TX_TYPES - 1; ++i)
vp9_diff_update_prob(&r, &fc->ext_tx_prob[j][i]);
read_ext_tx_probs(fc, &r);
#endif
#if CONFIG_SUPERTX
read_supertx_probs(fc, &r);
#endif
#if CONFIG_TX_SKIP
for (i = 0; i < 2; i++)

View File

@ -704,9 +704,6 @@ static void read_inter_frame_mode_info(VP9_COMMON *const cm,
mbmi->tx_size <= TX_16X16 &&
cm->base_qindex > 0 &&
mbmi->sb_type >= BLOCK_8X8 &&
#if CONFIG_SUPERTX
!supertx_enabled &&
#endif
!vp9_segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP) &&
!mbmi->skip) {
mbmi->ext_txfrm = vp9_read_tree(r, vp9_ext_tx_tree,

View File

@ -96,6 +96,24 @@ static void prob_diff_update(const vp9_tree_index *tree,
vp9_cond_prob_diff_update(w, &probs[i], branch_ct[i]);
}
static int prob_diff_update_savings(const vp9_tree_index *tree,
vp9_prob probs[/*n - 1*/],
const unsigned int counts[/*n - 1*/],
int n) {
int i;
unsigned int branch_ct[32][2];
int savings = 0;
// Assuming max number of probabilities <= 32
assert(n <= 32);
vp9_tree_probs_from_distribution(tree, branch_ct, counts);
for (i = 0; i < n - 1; ++i) {
savings += vp9_cond_prob_diff_update_savings(&probs[i], branch_ct[i]);
}
return savings;
}
static void write_selected_tx_size(const VP9_COMMON *cm,
const MACROBLOCKD *xd,
TX_SIZE tx_size, BLOCK_SIZE bsize,
@ -129,7 +147,6 @@ static int write_skip(const VP9_COMMON *cm, const MACROBLOCKD *xd,
static void update_skip_probs(VP9_COMMON *cm, vp9_writer *w) {
int k;
for (k = 0; k < SKIP_CONTEXTS; ++k)
vp9_cond_prob_diff_update(w, &cm->fc.skip_probs[k], cm->counts.skip[k]);
}
@ -142,6 +159,55 @@ static void update_switchable_interp_probs(VP9_COMMON *cm, vp9_writer *w) {
cm->counts.switchable_interp[j], SWITCHABLE_FILTERS, w);
}
#if CONFIG_EXT_TX
static void update_ext_tx_probs(VP9_COMMON *cm, vp9_writer *w) {
const int savings_thresh = vp9_cost_one(GROUP_DIFF_UPDATE_PROB) -
vp9_cost_zero(GROUP_DIFF_UPDATE_PROB);
int i;
int savings = 0;
int do_update = 0;
for (i = TX_4X4; i <= TX_16X16; ++i) {
savings += prob_diff_update_savings(vp9_ext_tx_tree, cm->fc.ext_tx_prob[i],
cm->counts.ext_tx[i], EXT_TX_TYPES);
}
do_update = savings > savings_thresh;
vp9_write(w, do_update, GROUP_DIFF_UPDATE_PROB);
if (do_update) {
for (i = TX_4X4; i <= TX_16X16; ++i) {
prob_diff_update(vp9_ext_tx_tree, cm->fc.ext_tx_prob[i],
cm->counts.ext_tx[i], EXT_TX_TYPES, w);
}
}
}
#endif // CONFIG_EXT_TX
#if CONFIG_SUPERTX
static void update_supertx_probs(VP9_COMMON *cm, vp9_writer *w) {
const int savings_thresh = vp9_cost_one(GROUP_DIFF_UPDATE_PROB) -
vp9_cost_zero(GROUP_DIFF_UPDATE_PROB);
int i, j;
int savings = 0;
int do_update = 0;
for (i = 0; i < PARTITION_SUPERTX_CONTEXTS; ++i) {
for (j = 1; j < TX_SIZES; ++j) {
savings +=
vp9_cond_prob_diff_update_savings(&cm->fc.supertx_prob[i][j],
cm->counts.supertx[i][j]);
}
}
do_update = savings > savings_thresh;
vp9_write(w, do_update, GROUP_DIFF_UPDATE_PROB);
if (do_update) {
for (i = 0; i < PARTITION_SUPERTX_CONTEXTS; ++i) {
for (j = 1; j < TX_SIZES; ++j) {
vp9_cond_prob_diff_update(w, &cm->fc.supertx_prob[i][j],
cm->counts.supertx[i][j]);
}
}
}
}
#endif // CONFIG_SUPERTX
static void pack_mb_tokens(vp9_writer *w,
TOKENEXTRA **tp, const TOKENEXTRA *const stop,
vpx_bit_depth_t bit_depth) {
@ -588,9 +654,9 @@ static void write_modes_sb(VP9_COMP *cpi,
if (!supertx_enabled && cm->frame_type != KEY_FRAME &&
partition != PARTITION_NONE && bsize <= MAX_SUPERTX_BLOCK_SIZE) {
TX_SIZE supertx_size = bsize_to_tx_size(bsize);
vp9_prob prob = partition == PARTITION_SPLIT ?
cm->fc.supertxsplit_prob[supertx_size] :
cm->fc.supertx_prob[supertx_size];
vp9_prob prob =
cm->fc.supertx_prob[partition_supertx_context_lookup[partition]]
[supertx_size];
supertx_enabled = (xd->mi[0].mbmi.tx_size == supertx_size);
vp9_write(w, supertx_enabled, prob);
if (supertx_enabled) {
@ -1456,10 +1522,10 @@ static size_t write_compressed_header(VP9_COMP *cpi, uint8_t *data) {
vp9_write_nmv_probs(cm, cm->allow_high_precision_mv, &header_bc);
#if CONFIG_EXT_TX
for (i = TX_4X4; i <= TX_16X16; ++i) {
prob_diff_update(vp9_ext_tx_tree, cm->fc.ext_tx_prob[i],
cm->counts.ext_tx[i], EXT_TX_TYPES, &header_bc);
}
update_ext_tx_probs(cm, &header_bc);
#endif
#if CONFIG_SUPERTX
update_supertx_probs(cm, &header_bc);
#endif
#if CONFIG_TX_SKIP
for (i = 0; i < 2; i++)

View File

@ -1491,10 +1491,8 @@ static void encode_sb(VP9_COMP *cpi, const TileInfo *const tile,
xd->mi[0].mbmi.skip;
}
}
if (partition != PARTITION_SPLIT)
cm->counts.supertx[supertx_size][1]++;
else
cm->counts.supertxsplit[supertx_size][1]++;
cm->counts.supertx
[partition_supertx_context_lookup[partition]][supertx_size][1]++;
cm->counts.supertx_size[supertx_size]++;
#if CONFIG_EXT_TX
if (supertx_size < TX_32X32 && !xd->mi[0].mbmi.skip)
@ -1508,10 +1506,8 @@ static void encode_sb(VP9_COMP *cpi, const TileInfo *const tile,
return;
} else {
if (output_enabled) {
if (partition != PARTITION_SPLIT)
cm->counts.supertx[supertx_size][0]++;
else
cm->counts.supertxsplit[supertx_size][0]++;
cm->counts.supertx
[partition_supertx_context_lookup[partition]][supertx_size][0]++;
}
}
}
@ -2870,7 +2866,10 @@ static void rd_pick_partition(VP9_COMP *cpi, const TileInfo *const tile,
best_partition = pc_tree->partitioning;
pc_tree->partitioning = PARTITION_SPLIT;
sum_rdc.rate += vp9_cost_bit(cm->fc.supertxsplit_prob[supertx_size], 0);
sum_rdc.rate += vp9_cost_bit(
cm->fc.supertx_prob
[partition_supertx_context_lookup[PARTITION_SPLIT]][supertx_size],
0);
sum_rdc.rdcost =
RDCOST(x->rdmult, x->rddiv, sum_rdc.rate, sum_rdc.dist);
@ -2888,7 +2887,10 @@ static void rd_pick_partition(VP9_COMP *cpi, const TileInfo *const tile,
#endif
pc_tree);
tmp_rate += vp9_cost_bit(cm->fc.supertxsplit_prob[supertx_size], 1);
tmp_rate += vp9_cost_bit(
cm->fc.supertx_prob
[partition_supertx_context_lookup[PARTITION_SPLIT]][supertx_size],
1);
tmp_rd = RDCOST(x->rdmult, x->rddiv, tmp_rate, tmp_dist);
if (tmp_rd < sum_rdc.rdcost) {
sum_rdc.rdcost = tmp_rd;
@ -2952,7 +2954,10 @@ static void rd_pick_partition(VP9_COMP *cpi, const TileInfo *const tile,
best_partition = pc_tree->partitioning;
pc_tree->partitioning = PARTITION_SPLIT;
sum_rdc.rate += vp9_cost_bit(cm->fc.supertxsplit_prob[supertx_size], 0);
sum_rdc.rate += vp9_cost_bit(
cm->fc.supertx_prob
[partition_supertx_context_lookup[PARTITION_SPLIT]][supertx_size],
0);
sum_rdc.rdcost =
RDCOST(x->rdmult, x->rddiv, sum_rdc.rate, sum_rdc.dist);
@ -2970,7 +2975,10 @@ static void rd_pick_partition(VP9_COMP *cpi, const TileInfo *const tile,
#endif
pc_tree);
tmp_rate += vp9_cost_bit(cm->fc.supertxsplit_prob[supertx_size], 1);
tmp_rate += vp9_cost_bit(
cm->fc.supertx_prob
[partition_supertx_context_lookup[PARTITION_SPLIT]][supertx_size],
1);
tmp_rd = RDCOST(x->rdmult, x->rddiv, tmp_rate, tmp_dist);
if (tmp_rd < sum_rdc.rdcost) {
sum_rdc.rdcost = tmp_rd;
@ -3076,7 +3084,9 @@ static void rd_pick_partition(VP9_COMP *cpi, const TileInfo *const tile,
best_partition = pc_tree->partitioning;
pc_tree->partitioning = PARTITION_HORZ;
sum_rdc.rate += vp9_cost_bit(cm->fc.supertx_prob[supertx_size], 0);
sum_rdc.rate += vp9_cost_bit(
cm->fc.supertx_prob[partition_supertx_context_lookup[PARTITION_HORZ]]
[supertx_size], 0);
sum_rdc.rdcost = RDCOST(x->rdmult, x->rddiv, sum_rdc.rate, sum_rdc.dist);
if (!check_intra_sb(cpi, tile, mi_row, mi_col, bsize, pc_tree)) {
@ -3093,7 +3103,10 @@ static void rd_pick_partition(VP9_COMP *cpi, const TileInfo *const tile,
#endif
pc_tree);
tmp_rate += vp9_cost_bit(cm->fc.supertx_prob[supertx_size], 1);
tmp_rate += vp9_cost_bit(
cm->fc.supertx_prob
[partition_supertx_context_lookup[PARTITION_HORZ]][supertx_size],
1);
tmp_rd = RDCOST(x->rdmult, x->rddiv, tmp_rate, tmp_dist);
if (tmp_rd < sum_rdc.rdcost) {
sum_rdc.rdcost = tmp_rd;
@ -3187,8 +3200,9 @@ static void rd_pick_partition(VP9_COMP *cpi, const TileInfo *const tile,
TX_SIZE supertx_size = bsize_to_tx_size(bsize);
best_partition = pc_tree->partitioning;
pc_tree->partitioning = PARTITION_VERT;
sum_rdc.rate += vp9_cost_bit(cm->fc.supertx_prob[supertx_size], 0);
sum_rdc.rate += vp9_cost_bit(
cm->fc.supertx_prob[partition_supertx_context_lookup[PARTITION_VERT]]
[supertx_size], 0);
sum_rdc.rdcost = RDCOST(x->rdmult, x->rddiv, sum_rdc.rate, sum_rdc.dist);
if (!check_intra_sb(cpi, tile, mi_row, mi_col, bsize, pc_tree)) {
@ -3205,7 +3219,10 @@ static void rd_pick_partition(VP9_COMP *cpi, const TileInfo *const tile,
#endif
pc_tree);
tmp_rate += vp9_cost_bit(cm->fc.supertx_prob[supertx_size], 1);
tmp_rate += vp9_cost_bit(
cm->fc.supertx_prob
[partition_supertx_context_lookup[PARTITION_VERT]][supertx_size],
1);
tmp_rd = RDCOST(x->rdmult, x->rddiv, tmp_rate, tmp_dist);
if (tmp_rd < sum_rdc.rdcost) {
sum_rdc.rdcost = tmp_rd;

View File

@ -190,3 +190,12 @@ void vp9_cond_prob_diff_update(vp9_writer *w, vp9_prob *oldp,
vp9_write(w, 0, upd);
}
}
int vp9_cond_prob_diff_update_savings(vp9_prob *oldp,
const unsigned int ct[2]) {
const vp9_prob upd = DIFF_UPDATE_PROB;
vp9_prob newp = get_binary_prob(ct[0], ct[1]);
const int savings = vp9_prob_diff_update_savings_search(ct, *oldp, &newp,
upd);
return savings;
}

View File

@ -22,11 +22,13 @@ void vp9_write_prob_diff_update(vp9_writer *w,
void vp9_cond_prob_diff_update(vp9_writer *w, vp9_prob *oldp,
unsigned int *ct);
int vp9_cond_prob_diff_update_savings(vp9_prob *oldp,
const unsigned int ct[2]);
int vp9_prob_diff_update_savings_search(const unsigned int *ct,
vp9_prob oldp, vp9_prob *bestp,
vp9_prob upd);
int vp9_prob_diff_update_savings_search_model(const unsigned int *ct,
const vp9_prob *oldp,
vp9_prob *bestp,