speech-tools/grammar/scfg/EST_SCFG_inout.cc
2015-09-19 10:52:26 +02:00

599 lines
15 KiB
C++

/*************************************************************************/
/* */
/* Centre for Speech Technology Research */
/* University of Edinburgh, UK */
/* Copyright (c) 1997 */
/* All Rights Reserved. */
/* */
/* Permission is hereby granted, free of charge, to use and distribute */
/* this software and its documentation without restriction, including */
/* without limitation the rights to use, copy, modify, merge, publish, */
/* distribute, sublicense, and/or sell copies of this work, and to */
/* permit persons to whom this work is furnished to do so, subject to */
/* the following conditions: */
/* 1. The code must retain the above copyright notice, this list of */
/* conditions and the following disclaimer. */
/* 2. Any modifications must be clearly marked as such. */
/* 3. Original authors' names are not deleted. */
/* 4. The authors' names are not used to endorse or promote products */
/* derived from this software without specific prior written */
/* permission. */
/* */
/* THE UNIVERSITY OF EDINBURGH AND THE CONTRIBUTORS TO THIS WORK */
/* DISCLAIM ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING */
/* ALL IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS, IN NO EVENT */
/* SHALL THE UNIVERSITY OF EDINBURGH NOR THE CONTRIBUTORS BE LIABLE */
/* FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES */
/* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN */
/* AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, */
/* ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF */
/* THIS SOFTWARE. */
/* */
/*************************************************************************/
/* Author : Alan W Black */
/* Date : October 1997 */
/*-----------------------------------------------------------------------*/
/* */
/* Implementation of an inside-outside reestimation procedure for */
/* building a stochastic CFG seeded with a bracket corpus. */
/* Based on "Inside-Outside Reestimation from partially bracketed */
/* corpora", F Pereira and Y. Schabes. pp 128-135, 30th ACL, Newark, */
/* Delaware 1992. */
/* */
/* This should really be done in the log domain. Addition in the log */
/* domain can be done with a formula in Huang, Ariki and Jack */
/* log(a+b) = log(1 + e^(log(a)-log(b))) + log(b) */
/* */
/*=======================================================================*/
#include <cstdlib>
#include "EST_SCFG_Chart.h"
#include "EST_simplestats.h"
#include "EST_math.h"
#include "EST_TVector.h"
// Storage backing the EST_TVector<EST_bracketed_string> static members set
// below: one const default-constructed object and one mutable object.
static const EST_bracketed_string def_val_s;
static EST_bracketed_string error_return_s;
// Point the vector template's def_val / error_return static members at the
// two objects above.
// NOTE(review): presumably the default fill value and the object handed back
// on a bad index -- confirm against EST_TVector.
template <> const EST_bracketed_string *EST_TVector<EST_bracketed_string>::def_val=&def_val_s;
template <> EST_bracketed_string *EST_TVector<EST_bracketed_string>::error_return=&error_return_s;
// When INSTANTIATE_TEMPLATES is defined, pull in the template implementation
// and explicitly instantiate the vector for EST_bracketed_string here.
#if defined(INSTANTIATE_TEMPLATES)
#include "../base_class/EST_TVector.cc"
template class EST_TVector<EST_bracketed_string>;
#endif
// Fill corpus b from a LISP list of bracketed strings: b is resized to the
// length of the list and element n is set from the n-th list entry.
void set_corpus(EST_Bcorpus &b, LISP examples)
{
    b.resize(siod_llength(examples));

    int idx = 0;
    for (LISP rest = examples; rest != NIL; rest = cdr(rest))
    {
        b.a_no_check(idx).set_bracketed_string(car(rest));
        ++idx;
    }
}
// Put the object into its empty state: no tables, zero length, and bs set
// to NIL and registered through gc_protect().
void EST_bracketed_string::init()
{
    p_length = 0;
    valid_spans = 0;
    symbols = 0;

    bs = NIL;
    gc_protect(&bs);
}
// Default construction: an empty bracketed string.
EST_bracketed_string::EST_bracketed_string()
{
    init();
}
// Construct from a LISP bracketing: start empty, then build the symbol and
// valid-span tables from the given tree.
EST_bracketed_string::EST_bracketed_string(LISP string)
{
    init();
    set_bracketed_string(string);
}
// Release everything owned by the bracketed string.
EST_bracketed_string::~EST_bracketed_string()
{
    // Drop the LISP reference and undo the gc_protect() done in init()
    bs = NIL;
    gc_unprotect(&bs);

    delete [] symbols;

    // Free each row of the span table, then the table of rows itself
    for (int row = 0; row < p_length; ++row)
        delete [] valid_spans[row];
    delete [] valid_spans;
}
// Replace the held bracketing with `string`.
//
// Rebuilds the three derived tables:
//   p_length     number of atoms found by find_num_nodes()
//   symbols      one entry per leaf, filled by set_leaf_indices()
//   valid_spans  length() rows of length()+1 ints; a cell is set to 1 by
//                find_valid() for every span (start,end) that agrees with
//                the bracketing, all other cells are 0
//
// Tables from an earlier call are released first, so the function may be
// called repeatedly on the same object without leaking.
void EST_bracketed_string::set_bracketed_string(LISP string)
{
    int i,j;

    bs=NIL;

    // Free the previous tables.  The old row count is still in p_length at
    // this point (the destructor relies on the same relation), so the rows
    // must be freed before p_length is recomputed.  On a freshly init()ed
    // object all of these are null/zero and nothing happens.
    delete [] symbols;
    symbols = 0;
    for (i=0; i < p_length; i++)
        delete [] valid_spans[i];
    delete [] valid_spans;
    valid_spans = 0;

    p_length = find_num_nodes(string);
    symbols = new LISP[p_length];
    set_leaf_indices(string,0,symbols);
    bs = string;

    valid_spans = new int*[length()];
    for (i=0; i < length(); i++)
    {
        valid_spans[i] = new int[length()+1];
        // Clear the whole row so that no cell is left uninitialised
        for (j=0; j <= length(); j++)
            valid_spans[i][j] = 0;
    }

    // fill in valid table
    if (p_length > 0)
        find_valid(0,bs);
}
// Count the atoms in a LISP tree.  NIL itself is never counted as an atom.
int EST_bracketed_string::find_num_nodes(LISP string)
{
    if (string == NIL)
        return 0;
    if (!CONSP(string))
        return 1;   // a single atom
    return find_num_nodes(car(string)) + find_num_nodes(cdr(string));
}
// Walk the bracketing left to right, storing in syms[] the list cell that
// holds each leaf, starting at index i.  Returns the index following the
// last leaf stored.
int EST_bracketed_string::set_leaf_indices(LISP string,int i,LISP *syms)
{
    int next = i;

    for (LISP cell = string; cell != NIL; cell = cdr(cell))
    {
        if (CONSP(car(cell)))
        {
            // car is a tree: number its leaves first
            next = set_leaf_indices(car(cell),next,syms);
        }
        else
        {
            syms[next] = cell;
            next++;
        }
    }

    return next;
}
void EST_bracketed_string::find_valid(int s,LISP t) const
{
LISP l;
int c;
if (consp(t))
{
for (c=s,l=t; l != NIL; l=cdr(l))
{
c += num_leafs(car(l));
valid_spans[s][c] = 1;
}
find_valid(s,car(t));
find_valid(s+num_leafs(car(t)),cdr(t));
}
}
// Number of leaves below t: 0 for NIL, 1 for an atom, otherwise the sum
// over car and cdr.
int EST_bracketed_string::num_leafs(LISP t) const
{
    if (t == NIL)
        return 0;
    if (consp(t))
        return num_leafs(car(t)) + num_leafs(cdr(t));
    return 1;
}
// Construct with no inside/outside cache and empty n / d accumulators.
EST_SCFG_traintest::EST_SCFG_traintest(void) : EST_SCFG()
{
    // The cache tables are only built by init_io_cache()
    outside = 0;
    inside = 0;

    d.resize(0);
    n.resize(0);
}
// Destructor: the body is empty, so nothing is released explicitly here.
// In particular the inside/outside tables are not freed by this function;
// within this file they are released by clear_io_cache().
EST_SCFG_traintest::~EST_SCFG_traintest(void)
{
}
void EST_SCFG_traintest::load_corpus(const EST_String &filename)
{
set_corpus(corpus,vload(filename,1));
}
// From the formula in the paper
double EST_SCFG_traintest::f_I_cal(int c, int p, int i, int k)
{
// Find Inside probability
double res;
if (i == k-1)
{
res = prob_U(p,terminal(corpus.a_no_check(c).symbol_at(i)));
// printf("prob_U p %s (%d) %d m %s (%d) res %g\n",
// (const char *)nonterminal(p),p,
// i,
// (const char *)corpus.a_no_check(c).symbol_at(i),
// terminal(corpus.a_no_check(c).symbol_at(i)),
// res);
}
else if (corpus.a_no_check(c).valid(i,k) == TRUE)
{
int j;
double s=0;
int q,r;
for (q = 0; q < num_nonterminals(); q++)
for (r = 0; r < num_nonterminals(); r++)
{
double pBpqr = prob_B(p,q,r);
if (pBpqr > 0)
for (j=i+1; j < k; j++)
{
double in = f_I(c,q,i,j);
if (in > 0)
s += pBpqr * in * f_I(c,r,j,k);
}
}
res = s;
}
else
res = 0.0;
inside[p][i][k] = res;
// printf("f_I p %s i %d k %d res %g\n",
// (const char *)nonterminal(p),i,k,res);
return res;
}
// Compute the outside value of nonterminal p over span (i,k) of corpus
// sentence c, store it in outside[p][i][k] and return it.
//
// Three cases:
//   - the span is the whole sentence: 1.0 for the distinguished symbol,
//     0.0 for every other nonterminal;
//   - the span is marked valid for the sentence's bracketing: the sum
//     described in the body below;
//   - otherwise: 0.0.
double EST_SCFG_traintest::f_O_cal(int c, int p, int i, int k)
{
// Find Outside probability
double res;
if ((i == 0) && (k == corpus.a_no_check(c).length()))
{
if (p == distinguished_symbol()) // distinguished non-terminal
res = 1.0;
else
res = 0.0;
}
else if (corpus.a_no_check(c).valid(i,k) == TRUE)
{
// s1 accumulates the total over all (q,r); s2 and s3 are the two
// per-pair contributions and are reset for every pair.
double s1=0.0;
double s2,s3;
double pBqrp,pBqpr;
int j;
int q,r;
for (q = 0; q < num_nonterminals(); q++)
for (r = 0; r < num_nonterminals(); r++)
{
// prob_B is called elsewhere in this file as (mother, daughter1,
// daughter2), so this is the rule q -> r p: p is the second daughter.
pBqrp = prob_B(q,r,p);
s2 = s3 = 0.0;
if (pBqrp > 0)
{
// Sibling r covers (j,i) to the left of the span, parent q covers (j,k)
for (j=0;j < i; j++)
{
double out = f_O(c,q,j,k);
// the inside value is only looked up when the outside value is positive
if (out > 0)
s2 += out * f_I(c,r,j,i);
}
s2 *= pBqrp;
}
// Rule q -> p r: p is the first daughter.
pBqpr = prob_B(q,p,r);
if (pBqpr > 0)
{
// Sibling r covers (k,j) to the right of the span, parent q covers (i,j)
for (j=k+1;j <= corpus.a_no_check(c).length(); j++)
{
double out = f_O(c,q,i,j);
if (out > 0)
s3 += out * f_I(c,r,k,j);
}
s3 *= pBqpr;
}
s1 += s2 + s3;
}
res = s1;
}
else // not a valid bracketing
res = 0.0;
// Record the value in the outside table before returning it
outside[p][i][k] = res;
return res;
}
void EST_SCFG_traintest::reestimate_rule_prob_B(int c, int ri, int p, int q, int r)
{
// Re-estimate probability for binary rules
int i,j,k;
double n2=0;
double pBpqr = prob_B(p,q,r);
if (pBpqr > 0)
{
for (i=0; i <= corpus.a_no_check(c).length()-2; i++)
for (j=i+1; j <= corpus.a_no_check(c).length()-1; j++)
{
double d1 = f_I(c,q,i,j);
if (d1 == 0) continue;
for (k=j+1; k <= corpus.a_no_check(c).length(); k++)
{
double d2 = f_I(c,r,j,k);
if (d2 == 0) continue;
double d3 = f_O(c,p,i,k);
if (d3 == 0) continue;
n2 += d1 * d2 * d3;
}
}
n2 *= pBpqr;
}
// f_P(c) is probably redundant
double fp = f_P(c);
double n1,d1;
n1 = n2 / fp;
if (fp == 0) n1=0;
d1 = f_P(c,p) / fp;
if (fp == 0) d1=0;
// printf("n1 %f d1 %f n2 %f fp %f\n",n1,d1,n2,fp);
n[ri] += n1;
d[ri] += d1;
}
// Re-estimate probability for unary rules.
//
// Adds the contribution of corpus sentence c to the accumulators of the
// unary rule p -> m, which is rule number ri:
//   n[ri] += (sum over word positions holding m of prob_U(p,m) O(p,i-1,i)) / P(c)
//   d[ri] += f_P(c,p) / P(c)
// A sentence whose probability P(c) is zero contributes nothing.
void EST_SCFG_traintest::reestimate_rule_prob_U(int c,int ri, int p, int m)
{
    int i;
    double n2=0;

    // i is the END of the one-word span (i-1,i), so it has to run up to and
    // including length() for the last word of the sentence to be counted.
    // Stopping at length()-1 drops the final word from the numerator while
    // f_P(c,p) in the denominator still includes that span.
    for (i=1; i <= corpus.a_no_check(c).length(); i++)
        if (m == terminal(corpus.a_no_check(c).symbol_at(i-1)))
            n2 += prob_U(p,m) * f_O(c,p,i-1,i);

    double fP = f_P(c);
    if (fP != 0)
    {
        n[ri] += n2 / fP;
        d[ri] += f_P(c,p) / fP;
    }
}
// Probability of corpus sentence c: the inside value of the distinguished
// symbol over the span covering the whole sentence.
double EST_SCFG_traintest::f_P(int c)
{
    const int whole = corpus.a_no_check(c).length();
    return f_I(c,distinguished_symbol(),0,whole);
}
// Sum of inside * outside for nonterminal p over every span (start,end),
// start < end, of corpus sentence c.
double EST_SCFG_traintest::f_P(int c,int p)
{
    double total = 0;

    for (int start = 0; start < corpus.a_no_check(c).length(); start++)
    {
        for (int end = start+1; end <= corpus.a_no_check(c).length(); end++)
        {
            double out = f_O(c,p,start,end);
            // the inside value is only needed when the outside one is non-zero
            if (out != 0)
                total += f_I(c,p,start,end)*out;
        }
    }

    return total;
}
// Run inside-outside re-estimation passes over the corpus.
//
//   passes      loop runs while pass < passes
//   startpass   number of the first pass
//   checkpoint  when positive, the grammar is saved to outfile + ".NNN"
//               after every pass whose number is checkpoint-1 modulo
//               checkpoint; -1 (or any non-positive value) disables saving
//   spread      when positive, only sentences c with
//               (c + pass*spread) % 100 < spread are used in a pass
//   outfile     base name for checkpoint files
//
// Progress is written to stdout: the index of each sentence used, then one
// summary line per pass.
void EST_SCFG_traintest::reestimate_grammar_probs(int passes,
                                                  int startpass,
                                                  int checkpoint,
                                                  int spread,
                                                  const EST_String &outfile)
{
    // Iterate over the corpus accumulating factors for each rule.
    // This reduces the space requirements and recalculations of
    // values for each sentence.
    // Repeat training passes to number specified
    int pass = 0;
    double zero=0;
    double se;
    int ri,c;

    n.resize(rules.length());
    d.resize(rules.length());

    for (pass = startpass; pass < passes; pass++)
    {
        EST_Litem *r;
        double mC, lPc;

        d.fill(zero);
        n.fill(zero);
        set_rule_prob_cache();

        for (mC=0.0,lPc=0.0,c=0; c < corpus.length(); c++)
        {
            // For skipping some sentences to speed up convergence
            if ((spread > 0) && (((c+(pass*spread))%100) >= spread))
                continue;
            printf(" %d",c); fflush(stdout);
            if (corpus.a_no_check(c).length() == 0) continue;

            // The inside/outside tables are per sentence: build them,
            // accumulate every rule's numerator and denominator, drop them.
            init_io_cache(c,num_nonterminals());
            for (ri=0,r=rules.head(); r != 0; r=r->next(),ri++)
            {
                if (rules(r).type() == est_scfg_binary_rule)
                    reestimate_rule_prob_B(c,ri,
                                           rules(r).mother(),
                                           rules(r).daughter1(),
                                           rules(r).daughter2());
                else
                    reestimate_rule_prob_U(c,
                                           ri,
                                           rules(r).mother(),
                                           rules(r).daughter1());
            }
            lPc += safe_log(f_P(c));
            mC += corpus.a_no_check(c).length();
            clear_io_cache(c);
        }
        printf("\n");

        // Install the new probabilities, summing the squared change
        for (se=0.0,ri=0,r=rules.head(); r != 0; r=r->next(),ri++)
        {
            // Check the denominator before dividing: a rule with no
            // accumulated denominator gets probability 0.
            double n_prob = 0;
            if (d[ri] != 0)
                n_prob = n[ri]/d[ri];
            se += (n_prob-rules(r).prob())*(n_prob-rules(r).prob());
            rules(r).set_prob(n_prob);
        }

        // NOTE(review): mC is 0 when no sentence was used in this pass and
        // rules.length() may be 0; both quotients are then not finite.
        printf("pass %d cross entropy %g RMSE %f %f %d\n",
               pass,-(lPc/mC),sqrt(se/rules.length()),
               se,rules.length());

        // checkpoint > 0 rather than != -1: a value of 0 would otherwise
        // reach pass % 0.  Negative values never matched the test below,
        // so they still mean "no checkpoints".
        if ((checkpoint > 0) && ((pass % checkpoint) == checkpoint-1))
        {
            // ".%03d" of an int needs at most 12 characters plus the NUL
            char cp[20];
            sprintf(cp,".%03d",pass);
            save(outfile+cp);
            user_gc(NIL); // just to keep things neat
        }
    }
}
// Train a Stochastic CFG using the inside outside algorithm.
// All arguments are handed unchanged to reestimate_grammar_probs().
void EST_SCFG_traintest::train_inout(int passes,
                                     int startpass,
                                     int checkpoint,
                                     int spread,
                                     const EST_String &outfile)
{
    reestimate_grammar_probs(passes,
                             startpass,
                             checkpoint,
                             spread,
                             outfile);
}
// Build the arrays that cache the in/out values for corpus sentence c.
// Both tables are nt x dim x dim with dim = sentence length + 1, and every
// cell starts at -1.
// NOTE(review): -1 is presumably the "not yet computed" marker tested by
// f_I / f_O -- confirm against their definitions.
void EST_SCFG_traintest::init_io_cache(int c,int nt)
{
    const int dim = corpus.a_no_check(c).length()+1;

    inside = new double**[nt];
    outside = new double**[nt];

    for (int sym = 0; sym < nt; ++sym)
    {
        inside[sym] = new double*[dim];
        outside[sym] = new double*[dim];

        for (int row = 0; row < dim; ++row)
        {
            double *in_cells = new double[dim];
            double *out_cells = new double[dim];

            for (int col = 0; col < dim; ++col)
            {
                in_cells[col] = -1;
                out_cells[col] = -1;
            }

            inside[sym][row] = in_cells;
            outside[sym][row] = out_cells;
        }
    }
}
// Free the inside/outside tables built for corpus sentence c and reset both
// pointers to 0.  Does nothing when no inside table exists.
void EST_SCFG_traintest::clear_io_cache(int c)
{
    const int dim = corpus.a_no_check(c).length()+1;

    if (inside == 0)
        return;

    for (int sym = 0; sym < num_nonterminals(); ++sym)
    {
        for (int row = 0; row < dim; ++row)
        {
            delete [] inside[sym][row];
            delete [] outside[sym][row];
        }
        delete [] inside[sym];
        delete [] outside[sym];
    }

    delete [] inside;
    inside = 0;
    delete [] outside;
    outside = 0;
}
double EST_SCFG_traintest::cross_entropy()
{
double lPc=0,mC=0;
int c;
for (c=0; c < corpus.length(); c++)
{
lPc += log(f_P(c));
mC += corpus.a_no_check(c).length();
}
return -(lPc/mC);
}
// Test corpus against current grammar.
// Prints the cross entropy over the sentences that received a non-zero
// probability, together with the count of sentences that did not.
void EST_SCFG_traintest::test_corpus()
{
    // Lets try simply finding the cross entropy
    n.resize(rules.length());
    d.resize(rules.length());
    for (int rule = 0; rule < rules.length(); rule++)
    {
        n[rule] = 0.0;
        d[rule] = 0.0;
    }

    double words = 0.0;      // words in the sentences that parsed
    double log_prob = 0.0;   // summed log probability of those sentences
    int failed = 0;

    for (int c = 0; c < corpus.length(); c++)
    {
        // Progress output only for larger corpora
        if (corpus.length() > 50)
        {
            printf(" %d",c);
            fflush(stdout);
        }

        init_io_cache(c,num_nonterminals());
        double fP = f_P(c);
        if (fP != 0)
        {
            log_prob += safe_log(fP);
            words += corpus.a_no_check(c).length();
        }
        else
            failed++;
        clear_io_cache(c);
    }

    if (corpus.length() > 50)
        printf("\n");

    cout << "cross entropy " << -(log_prob/words) << " (" << failed << " failed out of " <<
        corpus.length() << " sentences )" << endl;
}