00001
00012 #ifdef HAVE_CONFIG_H
00013 #include "config.h"
00014 #else
00015 #ifdef _MSC_VER
00016 #include "msdevstudio/MSconfig.h"
00017 #endif
00018 #endif
00019
00020 #include "LMFitter.h"
00021
00022 #include "NumLinAlg.h"
00023 #include "StatedFCN.h"
00024
00025 #include <algorithm>
00026 #include <iostream>
00027
00028 #include <cmath>
00029 #include <cassert>
00030
00031 using namespace hippodraw::Numeric;
00032
00033 #ifdef ITERATOR_MEMBER_DEFECT
00034 using namespace std;
00035 #else
00036 using std::abs;
00037 using std::distance;
00038 using std::swap;
00039 using std::vector;
00040 using std::map;
00041 using std::string;
00042 using std::cout;
00043 using std::endl;
00044 #endif
00045
00046
00047 LMFitter::
00048 LMFitter ( const char * name )
00049 : Fitter ( name ),
00050 m_chi_cutoff ( 0.000001 ),
00051 m_start_lambda ( 0.001 ),
00052 m_lambda_shrink_factor( 9.8 ),
00053 m_lambda_expand_factor( 10.2 )
00054 {
00055 m_iter_params[ "chi_cutoff" ] = & m_chi_cutoff;
00056 m_iter_params[ "start_lambda"] = & m_start_lambda;
00057 m_iter_params[ "lambda_shrink_factor" ] = & m_lambda_shrink_factor;
00058 m_iter_params[ "lambda_expand_factor" ] = & m_lambda_expand_factor;
00059 }
00060
00061 Fitter *
00062 LMFitter::
00063 clone ( ) const
00064 {
00065 return new LMFitter ( *this );
00066 }
00067
00068 bool
00069 LMFitter::
00070 needsDerivatives () const
00071 {
00072 return true;
00073 }
00074
00077 void LMFitter::calcAlpha ()
00078 {
00079 m_fcn -> calcAlphaBeta ( m_alpha, m_beta );
00080 unsigned int num_parms = m_beta.size();
00081
00082 unsigned int j = 0;
00083 for ( ; j < num_parms; j++ ) {
00084 for ( unsigned int k = 0; k < j; k++ ) {
00085 m_alpha[k][j] = m_alpha[j][k];
00086 }
00087 }
00088
00089 j = 0;
00090 for ( ; j < num_parms; j++ ) {
00091 m_alpha[j][j] *= ( 1.0 + m_lambda );
00092 }
00093 }
00094
00100 int LMFitter::calcCovariance ( std::vector < std::vector < double > >& cov )
00101 {
00102 m_lambda = 0;
00103 calcAlpha ();
00104
00105
00106
00107
00108
00109 return invertMatrix ( m_alpha, cov );
00110 }
00111
00112 bool LMFitter::solveSystem ()
00113 {
00114 unsigned int num_parms = m_beta.size ();
00115
00116 vector< int > ipiv ( num_parms, 0 );
00117
00118 vector< int > indxr ( num_parms, -1 );
00119 vector< int > indxc ( num_parms, -1 );
00120
00121 unsigned int irow = UINT_MAX;
00122 unsigned int icol = UINT_MAX;
00123
00124 for ( unsigned int i = 0; i < num_parms; i++ ) {
00125 double big = 0.0;
00126
00127 for ( unsigned int j = 0; j < num_parms; j++ ) {
00128 if ( ipiv[j] != 1 ) {
00129
00130 for ( unsigned int k = 0; k < num_parms; k++ ) {
00131 if ( ipiv[k] == 0 ) {
00132 if ( abs ( m_alpha[j][k] ) >= big ) {
00133 big = abs ( m_alpha[j][k] );
00134 irow = j;
00135 icol = k;
00136 }
00137 }
00138 else if ( ipiv[k] > 1 ) {
00139 return false;
00140 }
00141 }
00142 }
00143 }
00144
00145 if ( irow == UINT_MAX ) {
00146 return false;
00147 }
00148
00149 ++ipiv[icol];
00150 if ( irow != icol ) {
00151 for ( unsigned int l = 0; l < num_parms; l++ ) {
00152 swap ( m_alpha[irow][l], m_alpha[icol][l] );
00153 }
00154 swap ( m_beta[irow], m_beta[icol] );
00155 }
00156 indxr[i] = irow;
00157 indxc[i] = icol;
00158 if ( m_alpha[icol][icol] == 0.0 ) {
00159 return false;
00160 }
00161 double pivinv = 1.0 / m_alpha[icol][icol];
00162 m_alpha[icol][icol] = 1.0;
00163
00164 for ( unsigned int l = 0; l < num_parms; l++ ) {
00165 m_alpha[icol][l] *= pivinv;
00166 }
00167 m_beta[icol] *= pivinv;
00168
00169 for ( unsigned int ll = 0; ll < num_parms; ll++ ) {
00170 if ( ll != icol ) {
00171 double dum = m_alpha[ll][icol];
00172 m_alpha[ll][icol] = 0.0;
00173
00174 for ( unsigned int l = 0; l < num_parms; l++ ) {
00175 m_alpha[ll][l] -= m_alpha[icol][l] * dum;
00176 }
00177 m_beta[ll] -= m_beta[icol] * dum;
00178 }
00179 }
00180 }
00181
00182 for ( int l = num_parms - 1; l >= 0; l-- ) {
00183 if ( indxr[l] != indxc[l] ) {
00184
00185 for ( unsigned int k = 0; k < num_parms; k++ ) {
00186 swap ( m_alpha[k][indxr[l]], m_alpha[k][indxc[l]] );
00187 }
00188 }
00189 }
00190 return true;
00191 }
00192
00193 bool LMFitter::calcStep ()
00194 {
00195 calcAlpha ();
00196 bool ok = solveSystem ();
00197
00198 return ok;
00199 }
00200
00201 bool LMFitter::calcBestFit ()
00202 {
00203 m_lambda = m_start_lambda;
00204
00205 int i = 0;
00206 for ( ; i < m_max_iterations; i++ ) {
00207
00208 double old_chisq = objectiveValue ();
00209
00210 vector< double > old_parms;
00211 m_fcn -> fillFreeParameters ( old_parms );
00212
00213 bool ok = calcStep ();
00214 assert ( old_parms.size() == m_beta.size() );
00215
00216 vector< double > new_parms ( old_parms );
00217 vector< double >::iterator pit = new_parms.begin ( );
00218 vector< double >::iterator dit = m_beta.begin ( );
00219
00220 while ( pit != new_parms.end () ) {
00221 *pit++ += *dit++;
00222 }
00223 m_fcn -> setFreeParameters ( new_parms );
00224
00225 double new_chisq = objectiveValue ();
00226
00227 if ( abs ( old_chisq - new_chisq ) < m_chi_cutoff ) break;
00228
00229 if ( new_chisq < old_chisq ) {
00230 m_lambda /= m_lambda_shrink_factor;
00231 }
00232 else {
00233 m_lambda *= m_lambda_expand_factor;
00234 m_fcn -> setFreeParameters ( old_parms );
00235 }
00236
00237 if ( ! ok ) return ok;
00238 }
00239
00240 return i < m_max_iterations;
00241 }
00242
00243 double LMFitter::iterParam( string name )
00244 {
00245
00246
00247 if( name == "max_iterations" )
00248 return m_max_iterations;
00249
00250
00251
00252 map< string, double * >::const_iterator it
00253 = m_iter_params.find ( name );
00254
00255 if ( it == m_iter_params.end () )
00256 cout << name << " is not a valid iteration parameter name" << endl;
00257 else
00258 return *m_iter_params[name];
00259
00260 return 0.0;
00261 }
00262
00263 int LMFitter::setIterParam( string name, double value )
00264 {
00265
00266
00267
00268 if( name == "max_iterations" )
00269 {
00270 m_max_iterations = ( int ) value;
00271 return EXIT_SUCCESS;
00272 }
00273
00274
00275
00276
00277 map< string, double * >::const_iterator it
00278 = m_iter_params.find ( name );
00279
00280 if ( it == m_iter_params.end () )
00281 {
00282 cout << name << " is not a valid iteration parameter name" << endl;
00283 return EXIT_FAILURE;
00284 }
00285 else
00286 {
00287 *m_iter_params[name] = value;
00288 return EXIT_SUCCESS;
00289 }
00290
00291 return EXIT_FAILURE;
00292 }