From 4f7c412d650e8493e0a356f4159be650cffb4d0b Mon Sep 17 00:00:00 2001 From: Thomas White Date: Fri, 7 Jul 2023 15:02:45 +0200 Subject: Factorise matrix operations This makes the code much clearer. Note that two opposing sign errors have been fixed in the gradient calculation. --- libcrystfel/src/detgeom.c | 44 +----------------------------- libcrystfel/src/predict-refine.c | 17 ++++-------- libcrystfel/src/utils.c | 59 ++++++++++++++++++++++++++++++++++++++++ libcrystfel/src/utils.h | 3 ++ 4 files changed, 69 insertions(+), 54 deletions(-) (limited to 'libcrystfel') diff --git a/libcrystfel/src/detgeom.c b/libcrystfel/src/detgeom.c index 8387dff4..397edc6a 100644 --- a/libcrystfel/src/detgeom.c +++ b/libcrystfel/src/detgeom.c @@ -230,48 +230,6 @@ void detgeom_translate_detector_m(struct detgeom *dg, double x, double y, double } -static gsl_matrix *invert(gsl_matrix *m) -{ - gsl_permutation *perm; - gsl_matrix *inv; - int s; - - perm = gsl_permutation_alloc(m->size1); - if ( perm == NULL ) { - ERROR("Couldn't allocate permutation\n"); - gsl_matrix_free(m); - return NULL; - } - - inv = gsl_matrix_alloc(m->size1, m->size2); - if ( inv == NULL ) { - ERROR("Couldn't allocate inverse\n"); - gsl_matrix_free(m); - gsl_permutation_free(perm); - return NULL; - } - - if ( gsl_linalg_LU_decomp(m, perm, &s) ) { - ERROR("Couldn't decompose matrix\n"); - gsl_matrix_free(m); - gsl_permutation_free(perm); - return NULL; - } - - if ( gsl_linalg_LU_invert(m, perm, inv) ) { - ERROR("Couldn't invert cell matrix:\n"); - gsl_matrix_free(m); - gsl_permutation_free(perm); - return NULL; - } - - gsl_permutation_free(perm); - gsl_matrix_free(m); - - return inv; -} - - gsl_matrix **make_panel_minvs(struct detgeom *dg) { int i; @@ -295,7 +253,7 @@ gsl_matrix **make_panel_minvs(struct detgeom *dg) gsl_matrix_set(M, 2, 1, p->fsz); gsl_matrix_set(M, 2, 2, p->ssz); - Minvs[i] = invert(M); + Minvs[i] = matrix_invert(M); if ( Minvs[i] == NULL ) { ERROR("Failed to calculate inverse panel matrix for %s\n", p->name); diff --git a/libcrystfel/src/predict-refine.c b/libcrystfel/src/predict-refine.c index fda3eaa4..00de8bb8 100644 --- a/libcrystfel/src/predict-refine.c +++ b/libcrystfel/src/predict-refine.c @@ -156,8 +156,7 @@ int fs_ss_gradient(int param, Reflection *refl, UnitCell *cell, gsl_matrix *M; double mu; gsl_matrix *dMdp; - gsl_matrix *minusMinvdMdp; - gsl_matrix *minusMinvdMdpMinv; + gsl_matrix *gM; /* M^-1 * dM/dx * M^-1 */ double fs, ss; get_indices(refl, &h, &k, &l); @@ -287,19 +286,15 @@ int fs_ss_gradient(int param, Reflection *refl, UnitCell *cell, } - minusMinvdMdp = gsl_matrix_calloc(3, 3); - gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, -1.0, Minv, dMdp, 0.0, minusMinvdMdp); - minusMinvdMdpMinv = gsl_matrix_calloc(3, 3); - gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, -1.0, minusMinvdMdp, Minv, 0.0, minusMinvdMdpMinv); - gsl_blas_dgemv(CblasNoTrans, 1.0, minusMinvdMdpMinv, t, 0.0, v); + gM = matrix_mult3(Minv, dMdp, Minv); + gsl_blas_dgemv(CblasNoTrans, -1.0, gM, t, 0.0, v); - *fsg = -(mu*gsl_vector_get(v, 1) - mu*fs*gsl_vector_get(v, 0)); - *ssg = -(mu*gsl_vector_get(v, 2) - mu*ss*gsl_vector_get(v, 0)); + *fsg = mu*(gsl_vector_get(v, 1) - fs*gsl_vector_get(v, 0)); + *ssg = mu*(gsl_vector_get(v, 2) - ss*gsl_vector_get(v, 0)); gsl_vector_free(v); - gsl_matrix_free(minusMinvdMdpMinv); + gsl_matrix_free(gM); gsl_matrix_free(dMdp); - gsl_matrix_free(minusMinvdMdp); gsl_vector_free(t); return 0; diff --git a/libcrystfel/src/utils.c b/libcrystfel/src/utils.c index ae423eed..262a32cc 100644 --- a/libcrystfel/src/utils.c +++ b/libcrystfel/src/utils.c @@ -109,6 +109,65 @@ void show_vector(gsl_vector *v) } +gsl_matrix *matrix_mult(gsl_matrix *A, gsl_matrix *B) +{ + gsl_matrix *r = gsl_matrix_calloc(A->size1, A->size2); + gsl_blas_dgemm(CblasNoTrans, CblasNoTrans, 1.0, A, B, 0.0, r); + return r; +} + + +gsl_matrix *matrix_mult3(gsl_matrix *A, gsl_matrix *B, gsl_matrix *C) +{ + gsl_matrix *tmp = matrix_mult(B, C); + gsl_matrix *r = matrix_mult(A, tmp); + gsl_matrix_free(tmp); + return r; +} + + +gsl_matrix *matrix_invert(gsl_matrix *m) +{ + gsl_permutation *perm; + gsl_matrix *inv; + int s; + + perm = gsl_permutation_alloc(m->size1); + if ( perm == NULL ) { + ERROR("Couldn't allocate permutation\n"); + gsl_matrix_free(m); + return NULL; + } + + inv = gsl_matrix_alloc(m->size1, m->size2); + if ( inv == NULL ) { + ERROR("Couldn't allocate inverse\n"); + gsl_matrix_free(m); + gsl_permutation_free(perm); + return NULL; + } + + if ( gsl_linalg_LU_decomp(m, perm, &s) ) { + ERROR("Couldn't decompose matrix\n"); + gsl_matrix_free(m); + gsl_permutation_free(perm); + return NULL; + } + + if ( gsl_linalg_LU_invert(m, perm, inv) ) { + ERROR("Couldn't invert cell matrix:\n"); + gsl_matrix_free(m); + gsl_permutation_free(perm); + return NULL; + } + + gsl_permutation_free(perm); + gsl_matrix_free(m); + + return inv; +} + + static int check_eigen(gsl_vector *e_val, int verbose) { int i; diff --git a/libcrystfel/src/utils.h b/libcrystfel/src/utils.h index 82a9aa1c..6d2ff253 100644 --- a/libcrystfel/src/utils.h +++ b/libcrystfel/src/utils.h @@ -77,6 +77,9 @@ extern void show_matrix(gsl_matrix *M); extern void show_vector(gsl_vector *M); extern gsl_vector *solve_svd(gsl_vector *v, gsl_matrix *M, int *n_filt, int verbose); +extern gsl_matrix *matrix_mult2(gsl_matrix *A, gsl_matrix *B); +extern gsl_matrix *matrix_mult3(gsl_matrix *A, gsl_matrix *B, gsl_matrix *C); +extern gsl_matrix *matrix_invert(gsl_matrix *m); extern size_t notrail(char *s); extern int convert_int(const char *str, int *pval); -- cgit v1.2.3