diff options
-rw-r--r-- | src/predict-refine.c | 138 | ||||
-rw-r--r-- | tests/prediction_gradient_check.c | 28 |
2 files changed, 148 insertions, 18 deletions
diff --git a/src/predict-refine.c b/src/predict-refine.c index 56a4cc2c..8de3885f 100644 --- a/src/predict-refine.c +++ b/src/predict-refine.c @@ -210,8 +210,8 @@ static double r_gradient(UnitCell *cell, int k, Reflection *refl, /* Returns d(xh-xpk)/dP + d(yh-ypk)/dP, where P = any parameter */ -static double pos_gradient(int param, struct reflpeak *rp, struct detector *det, - double lambda, UnitCell *cell) +static double x_gradient(int param, struct reflpeak *rp, struct detector *det, + double lambda, UnitCell *cell) { signed int h, k, l; double xpk, ypk, xh, yh; @@ -226,7 +226,7 @@ static double pos_gradient(int param, struct reflpeak *rp, struct detector *det, tt = asin(lambda * resolution(cell, h, k, l)); clen = rp->panel->clen; azi = atan2(yh, xh); - azf = 2.0*(cos(azi) + sin(azi)); /* FIXME: Why factor of 2? */ + azf = 2.0*cos(azi); /* FIXME: Why factor of 2? */ switch ( param ) { @@ -240,6 +240,61 @@ static double pos_gradient(int param, struct reflpeak *rp, struct detector *det, return l * lambda * clen / cos(tt); case REF_ASY : + return 0.0; + + case REF_BSY : + return 0.0; + + case REF_CSY : + return 0.0; + + case REF_ASZ : + return -h * lambda * clen * azf * sin(tt) / (cos(tt)*cos(tt)); + + case REF_BSZ : + return -k * lambda * clen * azf * sin(tt) / (cos(tt)*cos(tt)); + + case REF_CSZ : + return -l * lambda * clen * azf * sin(tt) / (cos(tt)*cos(tt)); + + } + + ERROR("Positional gradient requested for parameter %i?\n", param); + abort(); +} + + +/* Returns d(yh-ypk)/dP, where P = any parameter */ +static double y_gradient(int param, struct reflpeak *rp, struct detector *det, + double lambda, UnitCell *cell) +{ + signed int h, k, l; + double xpk, ypk, xh, yh; + double fsh, ssh; + double tt, clen, azi, azf; + + twod_mapping(rp->peak->fs, rp->peak->ss, &xpk, &ypk, rp->panel); + get_detector_pos(rp->refl, &fsh, &ssh); + twod_mapping(fsh, ssh, &xh, &yh, rp->panel); + get_indices(rp->refl, &h, &k, &l); + + tt = asin(lambda * resolution(cell, h, k, l)); + clen = rp->panel->clen; + azi = atan2(yh, xh); + azf = 2.0*sin(azi); /* FIXME: Why factor of 2? */ + + switch ( param ) { + + case REF_ASX : + return 0.0; + + case REF_BSX : + return 0.0; + + case REF_CSX : + return 0.0; + + case REF_ASY : return h * lambda * clen / cos(tt); case REF_BSY : @@ -273,7 +328,7 @@ static double r_dev(struct reflpeak *rp) } -static double pos_dev(struct reflpeak *rp, struct detector *det) +static double x_dev(struct reflpeak *rp, struct detector *det) { /* Peak position term */ double xpk, ypk, xh, yh; @@ -281,7 +336,19 @@ static double pos_dev(struct reflpeak *rp, struct detector *det) twod_mapping(rp->peak->fs, rp->peak->ss, &xpk, &ypk, rp->panel); get_detector_pos(rp->refl, &fsh, &ssh); twod_mapping(fsh, ssh, &xh, &yh, rp->panel); - return (xh-xpk) + (yh-ypk); + return xh-xpk; +} + + +static double y_dev(struct reflpeak *rp, struct detector *det) +{ + /* Peak position term */ + double xpk, ypk, xh, yh; + double fsh, ssh; + twod_mapping(rp->peak->fs, rp->peak->ss, &xpk, &ypk, rp->panel); + get_detector_pos(rp->refl, &fsh, &ssh); + twod_mapping(fsh, ssh, &xh, &yh, rp->panel); + return yh-ypk; } @@ -339,10 +406,42 @@ static int iterate(struct reflpeak *rps, int n, UnitCell *cell, } - /* Positional terms */ + /* Positional x terms */ + for ( k=0; k<9; k++ ) { + gradients[k] = x_gradient(k, &rps[i], image->det, + image->lambda, cell); + } + + for ( k=0; k<9; k++ ) { + + int g; + double v_c, v_curr; + + for ( g=0; g<9; g++ ) { + + double M_c, M_curr; + + /* Matrix is symmetric */ + if ( g > k ) continue; + + M_c = gradients[g] * gradients[k]; + M_curr = gsl_matrix_get(M, k, g); + gsl_matrix_set(M, k, g, M_curr + M_c); + gsl_matrix_set(M, g, k, M_curr + M_c); + + } + + v_c = x_dev(&rps[i], image->det); + v_c *= -gradients[k]; + v_curr = gsl_vector_get(v, k); + gsl_vector_set(v, k, v_curr + v_c); + + } + + /* Positional y terms */ for ( k=0; k<9; k++ ) { - gradients[k] = pos_gradient(k, &rps[i], image->det, - image->lambda, cell); + gradients[k] = y_gradient(k, &rps[i], image->det, + image->lambda, cell); } for ( k=0; k<9; k++ ) { @@ -364,7 +463,7 @@ static int iterate(struct reflpeak *rps, int n, UnitCell *cell, } - v_c = pos_dev(&rps[i], image->det); + v_c = y_dev(&rps[i], image->det); v_c *= -gradients[k]; v_curr = gsl_vector_get(v, k); gsl_vector_set(v, k, v_curr + v_c); @@ -407,11 +506,28 @@ static double residual(struct reflpeak *rps, int n, struct detector *det) { int i; double res = 0.0; + double r; + + r = 0.0; + for ( i=0; i<n; i++ ) { + r += EXC_WEIGHT * rps[i].Ih * pow(r_dev(&rps[i]), 2.0); + } + printf("%e ", r); + res += r; + + r = 0.0; + for ( i=0; i<n; i++ ) { + r += pow(x_dev(&rps[i], det), 2.0); + } + printf("%e ", r); + res += r; + r = 0.0; for ( i=0; i<n; i++ ) { - res += EXC_WEIGHT * rps[i].Ih * pow(r_dev(&rps[i]), 2.0); - res += pow(pos_dev(&rps[i], det), 2.0); + r += pow(y_dev(&rps[i], det), 2.0); } + printf("%e ", r); + res += r; return res; } diff --git a/tests/prediction_gradient_check.c b/tests/prediction_gradient_check.c index 15b4d803..953d5fde 100644 --- a/tests/prediction_gradient_check.c +++ b/tests/prediction_gradient_check.c @@ -85,7 +85,11 @@ static void scan(RefList *reflections, RefList *compare, break; case 1 : - vals[idx][i] = xh + yh; + vals[idx][i] = xh; + break; + + case 2 : + vals[idx][i] = yh; break; } @@ -290,10 +294,17 @@ static double test_gradients(Crystal *cr, double incr_val, int refine, rp.peak = &pk; rp.panel = &image->det->panels[0]; - cgrad = pos_gradient(refine, &rp, - crystal_get_image(cr)->det, - crystal_get_image(cr)->lambda, - crystal_get_cell(cr)); + if ( checkrxy == 1 ) { + cgrad = x_gradient(refine, &rp, + crystal_get_image(cr)->det, + crystal_get_image(cr)->lambda, + crystal_get_cell(cr)); + } else { + cgrad = y_gradient(refine, &rp, + crystal_get_image(cr)->det, + crystal_get_image(cr)->lambda, + crystal_get_cell(cr)); + } } get_partial(refl, &r1, &r2, &p); @@ -431,7 +442,7 @@ int main(int argc, char *argv[]) rng = gsl_rng_alloc(gsl_rng_mt19937); - for ( checkrxy=1; checkrxy<2; checkrxy++ ) { + for ( checkrxy=0; checkrxy<3; checkrxy++ ) { switch ( checkrxy ) { @@ -439,9 +450,12 @@ int main(int argc, char *argv[]) STATUS("Excitation error:\n"); break; case 1: - STATUS("x+y coordinate:\n"); + STATUS("x coordinate:\n"); break; default: + case 2: + STATUS("y coordinate:\n"); + break; STATUS("WTF??\n"); break; } |