Jama comments


I would like to report my experience in rewriting my program em3.java
using Jama. The em3 program was originally written with extensive calls to
the DoubleMatrix and DoubleVector classes of Visual Numerics' JNL. The new
program em4.java uses only Jama's Matrix class. I have attached both
programs. The result: em4 runs at 1/2 to 1/3 the speed of em3.

I think one reason for the slow operation of em4 is that I have to
treat a one-dimensional vector as a special case of a two-dimensional
Matrix, as in "new Matrix(m,1)" in em4, in order to use Jama's operations
such as element-wise operations and matrix multiplication. I would like to
have a new class added to Jama, say MathVector. This new MathVector
class should have methods corresponding to those of the Matrix class, such
as plus and plusEquals. More importantly, it should be possible to multiply
a Matrix object with a MathVector object if the dimensions match. We can
adopt Fortran 90's view that a vector is a row vector if it is
multiplied by a Matrix from the right and a column vector if it
is multiplied by a Matrix from the left. For the case of a vector
multiplied by a vector, we can introduce an innerProduct and an
outerProduct method.

It is my hope that a MathVector class will make programs run faster
and also read more clearly. By the way, I am very grateful to MathWorks
and NIST for this Jama package.

Jason Liao
Assistant Professor of Biostatistics
University of South Florida

Jason Liao
13201 Bruce B. Downs Blvd.
Tampa, FL 33612

phone: 813-974-2951           
fax:   813-974-4718
jliao@com1.med.usf.edu          
http://www.biostat.coph.usf.edu/~liaoj/liaoj.html
 import Liao.*;
 import java.io.*;
 import java.util.*;
 import VisualNumerics.math.*;

 public final class em3
 {
     static final int n = 576;  //number of observations
     static final int m = 3;    //quadratic regression
     static final int num_of_curves = 3; //number of components in the mixture model

     static final int num_iter = 400000, burn_in = 20;

     static double[][] y = new double[n][]; //individual observation over time
     static double[][][] x = new double[n][][]; //individual design matrix
     static double[] t = {1,4,6,9,14,19,24,29,34}; //times of observation used to construct x

     static double[][] group_mean = new double[num_of_curves][m]; //group center coefficients
     static double[][] big_sigma_inv = new double[m][m];
     static double small_sigma2;

     static double[] prob = new double[num_of_curves]; //probability of belonging to one of the four curves
     static double[][] prob_local = new double[n][num_of_curves];  //marginal probability for each subject

     static double[][] theta = new double[n][]; //individual coefficients
     static int[] c_indi = new int[n]; //individual group indicator

     static double[][][] cross_xx = new double[n][][];  //saved for gene_theta
     static double[][] cross_xy = new double[n][];      //saved for gene_theta

   public static void main(String[] arg)
        throws IllegalArgumentException, MathException, IOException
   {
        hopkins_read.read_data(x, y, t, n);

        for(int i=0; i<n; i++)
           {
              cross_xx[i] = DoubleMatrix.multiply(DoubleMatrix.transpose(x[i]), x[i]);
              cross_xy[i] = DoubleMatrix.multiply(DoubleMatrix.transpose(x[i]), y[i]);
           }

        set_initials();

        Array.assigned_scalar(prob_local, 1./num_of_curves);

        EM_iteration();
   }

   public static void set_initials() throws IllegalArgumentException, MathException, IOException
   {
         Array.assigned_scalar(prob,1./num_of_curves);

         small_sigma2 = .00000001;

         for(int i=0; i<m; i++) big_sigma_inv[i][i] = 1.;

         for(int i=0; i<n; i++) c_indi[i]  = (int) (Math.random()*num_of_curves);

         theta = em1_subs.gene_theta(big_sigma_inv, cross_xx, cross_xy, c_indi, group_mean, small_sigma2);

         for(int i=0; i<n; i++)
            {
               int k = c_indi[i];
               group_mean[k] = Array.self_add(group_mean[k], theta[i]);
            }

          group_mean = Array.divide(group_mean, n/3.);

          double[][] big_sigma = new double[m][m];
          double[] diff;

          for(int i=0; i<n; i++)
             {
                int k = c_indi[i];
                diff = Array.subtract(theta[i], group_mean[k]);
                big_sigma = Array.self_add(big_sigma, Array.outer_product(diff, diff));
             }
             big_sigma = Array.divide(big_sigma, n);

             big_sigma_inv = DoubleMatrix.inverse(big_sigma);
    }

    public static void EM_iteration() throws IllegalArgumentException, MathException, IOException
    {
          int num_total_obs = 0;   //total number of time points observed
          for(int i=0; i<n; i++) num_total_obs += y[i].length;

          double[] c_term = new double[num_of_curves], c_local = new double[num_of_curves];   //all initialized to zero
          double small_sigma2_term = 0, small_sigma2_local;

          double[][] term1 = new double[m][m], term1_local = new double[m][m];
          double[][] term2_local = new double[num_of_curves][m], term2 = new double[num_of_curves][m], big_sigma;

          double weight;
          double[] mu, diff;
          int num_MH = 25;

          for(int iter=0; iter<num_iter; iter++)
          {
             big_sigma_inv = Array.self_multiply(big_sigma_inv, 1. + 20./(iter+1));

             for(int i=0; i<num_MH; i++) sample_MH();

             update_prob_local(iter);

             weight = 2./(iter + 2.);  //weight for the EM expectation

             for(int j=0; j<num_of_curves; j++)
             {
                c_local[j] = 0;
                for(int i=0; i<n; i++) if(c_indi[i] == j) c_local[j]++;
             }

             c_term = Array.self_add(Array.self_multiply(c_term, 1.-weight), Array.multiply(c_local, weight));

             if(iter > 10) prob = Array.divide(c_term, n);   //update

             small_sigma2_local = 0.;
             for(int i=0; i<n; i++)
                {
                   mu =  DoubleMatrix.multiply(x[i], theta[i]);
                   diff = Array.subtract(y[i], mu);
                   small_sigma2_local += DoubleVector.innerProduct(diff, diff);
                }

              small_sigma2_term = small_sigma2_term*(1.-weight) + small_sigma2_local*weight;

              if(iter > 10) small_sigma2 = small_sigma2_term / num_total_obs;      //update

              Array.assigned_scalar(term2_local, 0.);
              for(int j=0; j<num_of_curves; j++)
                  for(int i=0; i<n; i++)
                     if(c_indi[i] == j) term2_local[j] = Array.self_add(term2_local[j], theta[i]);

              term2 = Array.self_add(Array.self_multiply(term2, 1.-weight), Array.multiply(term2_local, weight));

              if(iter > 10) for(int j=0; j<num_of_curves; j++) group_mean[j] = Array.divide(term2[j], c_term[j]);  //update

              Array.assigned_scalar(term1_local, 0.);
              for(int i=0; i<n; i++)
                 for(int j=0; j<num_of_curves; j++)
                    {
                       if(c_indi[i] == j)
                       {
                          diff = Array.subtract(theta[i], group_mean[j]);
                          term1_local = Array.self_add(term1_local, Array.outer_product(diff, diff));
                       }
                    }

              term1 = Array.self_add(Array.self_multiply(term1, 1.-weight), Array.multiply(term1_local, weight));

              if(iter > 10)
              {
                 big_sigma = Array.divide(term1, n);     //update
                 big_sigma = Array.symmetrize(big_sigma);

                 big_sigma_inv = DoubleMatrix.inverse(big_sigma);
              }

              if(iter % 10 != 0) continue;
              System.out.println(iter);

              Array.print(group_mean);

              for(int i=0; i<num_of_curves; i++) System.out.print(Formatter.format(prob[i]));
              System.out.println();
          }
     }

     public static void update_prob_local(int iter)
     {
         double weight = 1./(2.*iter + 10.); //weight for updating prob_local only
         for(int i=0; i<n; i++)  //update the local marginal probability
         {
             for(int j=0; j<num_of_curves; j++)
              {
                 if(c_indi[i] == j) prob_local[i][j] = prob_local[i][j]*(1.- weight) + weight;
                 else prob_local[i][j] = prob_local[i][j]*(1.- weight);
              }
         }
     }

      public static void sample_MH() throws IllegalArgumentException, MathException
      {
         int c_indi_new;
         double[] theta_new;

         double[][] D_inv, D;
         double[] mu, temp, diff;
         double alpha, prob_approx_new, prob_approx_old;

         for(int i=0; i<n; i++)
         {
            c_indi_new = Liao.Random.multinomial(prob_local[i]);
            int k = c_indi_new;

            D_inv = Array.divide(cross_xx[i], small_sigma2);
            D_inv = Array.self_add(D_inv, big_sigma_inv);

            D = DoubleMatrix.inverse(D_inv);
            D = Array.symmetrize(D);

            mu = Array.divide(cross_xy[i], small_sigma2);

            temp = DoubleMatrix.multiply(big_sigma_inv, group_mean[k]);

            mu = Array.self_add(mu, temp);
            mu = DoubleMatrix.multiply(D, mu);

            theta_new = Liao.Random.normal(mu, D);

            diff = Array.subtract(theta_new, mu);
            prob_approx_new = Matrix.quadratic_form(D_inv, diff);
            prob_approx_new = Math.exp(-prob_approx_new/2.)*prob_local[i][k];

            diff = Array.subtract(theta[i], mu);
            prob_approx_old = Matrix.quadratic_form(D_inv, diff);
            k = c_indi[i];
            prob_approx_old = Math.exp(-prob_approx_old/2.)*prob_local[i][k];

            alpha = prob_true(c_indi_new, theta_new, x[i], y[i]) / prob_true(c_indi[i], theta[i], x[i], y[i]);
            alpha = alpha * prob_approx_old / prob_approx_new;

            //System.out.println("alpha  " +alpha);

            if(Math.random() < alpha)
                 {
                     c_indi[i] = c_indi_new;
                     theta[i] = theta_new; //no danger as elements of theta_new will not be manipulated
                 }
          }
     }

     public static double prob_true(int c_indi_single, double[] theta_single, double[][] x_single, double[] y_single) throws IllegalArgumentException, MathException
     {
        int k = c_indi_single;

        double[] diff = Array.subtract(theta_single, group_mean[k]);
        double probab1 = Matrix.quadratic_form(big_sigma_inv, diff);
        probab1 = Math.exp(-probab1/2.);

        double[] mu = DoubleMatrix.multiply(x_single, theta_single);
        diff = Array.subtract(y_single, mu);
        double probab2 = DoubleVector.innerProduct(diff, diff);
        probab2 = Math.exp(-probab2/small_sigma2/2.);

        double true_prob = prob[k]*probab1*probab2;
        return true_prob;
     }
  }

 import Jama.*;
 import java.io.*;
 import java.util.*;
 import Liao.*;

 /**
  * Stochastic EM fit of a mixture of {@code num_of_curves} regression curves,
  * written against Jama's {@code Matrix} class and the author's {@code Liao}
  * helper package. Vectors are stored as {@code m x 1} (column) matrices.
  *
  * <p>Each subject {@code i} has observations {@code y[i]}, a design matrix
  * {@code x[i]}, a coefficient vector {@code theta[i]} and a group indicator
  * {@code c_indi[i]}. Every EM iteration runs several Metropolis-Hastings
  * sweeps over all subjects ({@link #sample_MH()}) and then refreshes
  * {@code prob}, {@code small_sigma2}, {@code group_mean} and
  * {@code big_sigma_inv} from weighted running averages.
  *
  * <p>All state is kept in static fields; the class is a single-run program
  * started from {@link #main(String[])}.
  *
  * <p>NOTE(review): {@code MatrixOperation}, {@code hopkins_read} and
  * {@code Liao.Random} are not defined in this file; their exact semantics
  * should be confirmed against their sources.
  */
 public final class em4
 {
     static final int n = 576;  //number of observations
     static final int m = 3;    //quadratic regression
     static final int num_of_curves = 3; //number of components in the mixture model

     static final int num_iter = 10000, burn_in = 20;  //burn_in is not referenced inside this class

     static Matrix[] y = new Matrix[n]; //individual observation over time
     static Matrix[] x = new Matrix[n]; //individual design matrix

     static double[] t = {1,4,6,9,14,19,24,29,34}; //times of observation used to construct x

     static Matrix[] group_mean = new Matrix[num_of_curves]; //group center coefficients, one dimensional matrix
     static Matrix big_sigma_inv;  //inverse of the m x m matrix big_sigma

     static double small_sigma2;  //running squared-residual total divided by the number of observed time points

     static double[] prob = new double[num_of_curves]; //probability of belonging to each of the num_of_curves curves
     static double[][] prob_local = new double[n][num_of_curves];  //marginal probability for each subject

     static Matrix[] theta = new Matrix[n]; //individual coefficients, one dimensional matrix
     static int[] c_indi = new int[n]; //individual group indicator

     static Matrix[] cross_xx = new Matrix[n];  //x[i]' x[i], saved for gene_theta
     static Matrix[] cross_xy = new Matrix[n];  //x[i]' y[i], saved for gene_theta

     static double alpha_sum, alpha_count;  //sum and count of acceptance probabilities, reset every EM iteration

     /**
      * Reads the data, precomputes the per-subject cross products, allocates
      * the coefficient vectors, sets the starting values and runs the EM
      * iterations.
      *
      * @param arg command-line arguments (not used)
      */
     public static void main(String[] arg)
          throws IllegalArgumentException, IOException
     {
         hopkins_read.read_data(x, y, t, n);

         for(int i=0; i<n; i++) {
             cross_xx[i] = x[i].transpose().times(x[i]);
             cross_xy[i] = x[i].transpose().times(y[i]);
         }

         for(int i=0; i<num_of_curves; i++) group_mean[i] = new Matrix(m,1);  //column vector
         for(int i=0; i<n; i++) theta[i] = new Matrix(m,1);                  //column vector

         set_initials();

         Array.assigned_scalar(prob_local, 1./num_of_curves);

         EM_iteration();
     }

     /**
      * Sets the starting values: equal group probabilities, a tiny
      * {@code small_sigma2}, random group indicators, an initial draw of
      * {@code theta}, and from those the initial {@code group_mean} and
      * {@code big_sigma_inv}.
      */
     public static void set_initials() throws IllegalArgumentException, IOException
     {
         Liao.Array.assigned_scalar(prob, 1./num_of_curves);

         small_sigma2 = .00000001;

         big_sigma_inv = Matrix.identity(m,m);

         //uniformly random group in 0 .. num_of_curves-1 for every subject
         for(int i=0; i<n; i++) c_indi[i] = (int) (Math.random()*num_of_curves);

         gene_theta();

         //sum the coefficients of the subjects assigned to each group
         for(int i=0; i<n; i++) {
             int k = c_indi[i];
             group_mean[k] = group_mean[k].plusEquals(theta[i]);
         }

         //divide by the expected group size n/num_of_curves (was hard-coded as 3./n)
         for(int i=0; i<num_of_curves; i++) group_mean[i] = group_mean[i].timesEquals((double) num_of_curves/n);

         Matrix big_sigma = new Matrix(m,m);

         //average outer product of the deviations from the group means
         for(int i=0; i<n; i++)
         {
             int k = c_indi[i];
             Matrix diff = theta[i].minus(group_mean[k]);
             big_sigma = big_sigma.plusEquals(MatrixOperation.outer_product(diff, diff));
         }
         big_sigma = big_sigma.timesEquals(1./n);

         big_sigma_inv = big_sigma.inverse();
     }

     /**
      * Runs {@code num_iter} EM iterations, writing the group probabilities to
      * {@code em4_prob.out} and the average acceptance probability to
      * {@code em4_alpha.out} on every tenth iteration.
      *
      * <p>Both output files are closed (and thereby flushed) when the run ends,
      * whether normally or through an exception.
      *
      * @throws IOException if an output file cannot be opened
      */
     public static void EM_iteration() throws IllegalArgumentException, IOException
     {
         PrintStream output_file1 = new PrintStream(new FileOutputStream("em4_prob.out"));
         try {
             PrintStream output_file2 = new PrintStream(new FileOutputStream("em4_alpha.out"));
             try {
                 run_EM_loop(output_file1, output_file2);
             } finally {
                 output_file2.close();
             }
         } finally {
             output_file1.close();
         }
     }

     /**
      * The EM loop proper.
      *
      * <p>Each iteration scales {@code big_sigma_inv}, performs {@code num_MH}
      * Metropolis-Hastings sweeps, updates {@code prob_local}, and blends the
      * statistics of the current sample into running averages with weight
      * {@code 2/(iter+2)}. The model parameters are only overwritten from those
      * averages once {@code iter > 10}.
      *
      * @param output_file1 receives one line of group probabilities per report
      * @param output_file2 receives the average acceptance probability per report
      */
     private static void run_EM_loop(PrintStream output_file1, PrintStream output_file2)
          throws IllegalArgumentException, IOException
     {
         int num_total_obs = 0;   //total number of time points observed
         for(int i=0; i<n; i++) num_total_obs += y[i].getRowDimension();

         int num_MH = 25;   //Metropolis-Hastings sweeps per EM iteration

         Matrix c_term = new Matrix(num_of_curves, 1), c_local = new Matrix(num_of_curves, 1);   //all initialized to zero

         double small_sigma2_term = 0, small_sigma2_local;

         Matrix term1 = new Matrix(m,m);

         Matrix[] term2 = new Matrix[num_of_curves], term2_local = new Matrix[num_of_curves];

         for(int i=0; i<num_of_curves; i++) term2[i] = new Matrix(m,1);

         for(int iter=0; iter<num_iter; iter++)
         {
             //The explicit System.gc() that used to open every iteration was
             //removed: it has no effect on the results and only costs time.

             //multiply big_sigma_inv by a factor that shrinks towards 1 as iter grows
             big_sigma_inv = big_sigma_inv.timesEquals(1. + 20./(iter+1));

             alpha_sum = 0.;
             alpha_count = 0.;

             for(int i=0; i<num_MH; i++) sample_MH();   //real simulations happen here

             update_prob_local(iter);

             double weight = 2./(iter + 2.);  //weight for the EM expectation

             //number of subjects currently assigned to each group
             for(int j=0; j<num_of_curves; j++) {
                 double temp1 = 0.;
                 for(int i=0; i<n; i++) if(c_indi[i] == j) temp1++;
                 c_local.set(j, 0, temp1);
             }

             c_term = c_term.timesEquals(1.-weight).plusEquals(c_local.times(weight));

             if(iter > 10) prob = c_term.times(1./n).getColumnPackedCopy();   //update

             //sum of squared residuals y[i] - x[i]*theta[i] over all subjects
             small_sigma2_local = 0.;
             for(int i=0; i<n; i++) {
                 Matrix mu = x[i].times(theta[i]);
                 Matrix diff = y[i].minus(mu);
                 small_sigma2_local += MatrixOperation.inner_product(diff, diff);
             }

             small_sigma2_term = small_sigma2_term*(1.-weight) + small_sigma2_local*weight;

             if(iter > 10) small_sigma2 = small_sigma2_term / num_total_obs;      //update

             //per-group sum of the current coefficients
             for(int j=0; j<num_of_curves; j++) {
                 term2_local[j] = new Matrix(m, 1);
                 for(int i=0; i<n; i++) if(c_indi[i] == j) term2_local[j] = term2_local[j].plusEquals(theta[i]);
             }

             for(int i=0; i<num_of_curves; i++) term2[i] = term2[i].timesEquals(1.- weight).plusEquals(term2_local[i].times(weight));

             if(iter > 10) for(int j=0; j<num_of_curves; j++) group_mean[j] = term2[j].times(1./c_term.get(j, 0));  //update

             //sum of outer products of the deviations from the group means
             Matrix term1_local = new Matrix(m, m);
             for(int i=0; i<n; i++)
                 for(int j=0; j<num_of_curves; j++)
                 {
                     if(c_indi[i] == j) {
                         Matrix diff = theta[i].minus(group_mean[j]);
                         term1_local = term1_local.plusEquals(MatrixOperation.outer_product(diff, diff));
                     }
                 }

             term1 = term1.timesEquals(1.- weight).plusEquals(term1_local.times(weight));

             if(iter > 10) {
                 Matrix big_sigma = term1.times(1./n);     //update
                 big_sigma_inv = big_sigma.inverse();
             }

             //progress report on every tenth iteration only
             if(iter % 10 != 0) continue;
             System.out.println(iter);

             for(int i=0; i<num_of_curves; i++) group_mean[i].transpose().print(15, 8);
             for(int i=0; i<num_of_curves; i++) System.out.print(prob[i]+"  ");
             System.out.println();
             System.out.println("average alpha =  "+alpha_sum/alpha_count);

             for(int i=0; i<num_of_curves; i++) output_file1.print(prob[i]+"  ");
             output_file1.println();
             output_file2.println(alpha_sum/alpha_count);
         }
     }

     /**
      * Moves each subject's {@code prob_local} row towards its current group:
      * every entry is scaled by {@code 1 - weight} and the entry of the current
      * group additionally receives {@code weight}, where
      * {@code weight = 1/(2*iter + 10)}.
      *
      * @param iter zero-based index of the current EM iteration
      */
     public static void update_prob_local(int iter)
     {
         double weight = 1./(2.*iter + 10.); //weight for updating prob_local only
         for(int i=0; i<n; i++)  //update the local marginal probability
         {
             for(int j=0; j<num_of_curves; j++) {
                 if(c_indi[i] == j) prob_local[i][j] = prob_local[i][j]*(1.- weight) + weight;
                 else prob_local[i][j] = prob_local[i][j]*(1.- weight);
             }
         }
     }

     /**
      * One Metropolis-Hastings sweep over all subjects.
      *
      * <p>For subject {@code i} a new group is drawn with
      * {@code Liao.Random.multinomial(prob_local[i])} and a new coefficient
      * vector with {@code Liao.Random.normal(mu, D)}, where
      * {@code D = (cross_xx[i]/small_sigma2 + big_sigma_inv)^-1} and
      * {@code mu = D * (cross_xy[i]/small_sigma2 + big_sigma_inv * group_mean[k])}
      * for the proposed group {@code k}. The acceptance ratio is capped at 1,
      * added to {@code alpha_sum}/{@code alpha_count}, and the proposal is
      * accepted when a uniform random number is below it.
      *
      * <p>NOTE(review): {@code Liao.Random.normal(mu, D)} presumably draws from a
      * multivariate normal with mean {@code mu} and covariance {@code D} - confirm.
      * NOTE(review): the proposal density of the current state is evaluated with
      * the {@code mu} of the proposed group, not of the current group - confirm
      * this is intended.
      * NOTE(review): if both {@code prob_true} values underflow to 0 the ratio is
      * NaN, which rejects the proposal and turns {@code alpha_sum} into NaN.
      */
     public static void sample_MH() throws IllegalArgumentException
     {
         for(int i=0; i<n; i++)
         {
             int c_indi_new = Liao.Random.multinomial(prob_local[i]);
             int k = c_indi_new;

             Matrix D_inv = cross_xx[i].times(1./small_sigma2);
             D_inv = D_inv.plusEquals(big_sigma_inv);
             Matrix D = D_inv.inverse();

             Matrix mu = cross_xy[i].times(1./small_sigma2);
             Matrix temp = big_sigma_inv.times(group_mean[k]);
             mu = mu.plusEquals(temp);
             mu = D.times(mu);

             Matrix theta_new = Liao.Random.normal(mu, D);

             //unnormalised proposal density of the proposed state
             Matrix diff = theta_new.minus(mu);

             double prob_approx_new = MatrixOperation.quadratic_form(D_inv, diff);
             prob_approx_new = Math.exp(-prob_approx_new/2.)*prob_local[i][k];

             //unnormalised proposal density of the current state
             k = c_indi[i];
             diff = theta[i].minus(mu);
             double prob_approx_old = MatrixOperation.quadratic_form(D_inv, diff);
             prob_approx_old = Math.exp(-prob_approx_old/2.)*prob_local[i][k];

             //target ratio times reverse/forward proposal ratio
             double alpha = prob_true(c_indi_new, theta_new, x[i], y[i]) / prob_true(c_indi[i], theta[i], x[i], y[i]);
             alpha = alpha * prob_approx_old / prob_approx_new;

             alpha = Math.min(alpha, 1.);

             alpha_sum += alpha;
             alpha_count++;

             if(Math.random() < alpha)
             {
                 c_indi[i] = c_indi_new;
                 theta[i] = theta_new; //no danger as elements of theta_new will not be manipulated
             }
         }
     }

     /**
      * Unnormalised target density of one subject's state:
      * {@code prob[k] * exp(-q/2) * exp(-r/(2*small_sigma2))}, where {@code q} is
      * the quadratic form of {@code theta_single - group_mean[k]} in
      * {@code big_sigma_inv} and {@code r} is the inner product of the residual
      * {@code y_single - x_single*theta_single} with itself.
      *
      * @param c_indi_single group indicator {@code k}
      * @param theta_single  coefficient vector of the subject
      * @param x_single      design matrix of the subject
      * @param y_single      observations of the subject
      * @return the unnormalised density value
      */
     public static double prob_true(int c_indi_single, Matrix theta_single, Matrix x_single, Matrix y_single)
            throws IllegalArgumentException
     {
         int k = c_indi_single;

         Matrix diff = theta_single.minus(group_mean[k]);
         double probab1 = MatrixOperation.quadratic_form(big_sigma_inv, diff);
         probab1 = Math.exp(-probab1/2.);

         Matrix mu = x_single.times(theta_single);

         diff = y_single.minus(mu);

         double probab2 = MatrixOperation.inner_product(diff, diff);
         probab2 = Math.exp(-probab2/small_sigma2/2.);

         return prob[k]*probab1*probab2;
     }

     /**
      * Draws an initial {@code theta[i]} for every subject from
      * {@code Liao.Random.normal(mu, D)}, with {@code D} and {@code mu} built
      * from the subject's cross products, {@code small_sigma2},
      * {@code big_sigma_inv} and the mean of the subject's current group, in the
      * same way as in {@link #sample_MH()}.
      */
     public static void gene_theta() throws IllegalArgumentException
     {
         for(int i=0; i<n; i++)
         {
             Matrix D_inv = cross_xx[i].times(1./small_sigma2);
             D_inv = D_inv.plusEquals(big_sigma_inv);

             Matrix D = D_inv.inverse();

             Matrix mu = cross_xy[i].times(1./small_sigma2);

             int k = c_indi[i];
             Matrix temp = big_sigma_inv.times(group_mean[k]);

             mu = mu.plusEquals(temp);
             mu = D.times(mu);

             theta[i] = Liao.Random.normal(mu, D);
         }
     }
 }



Date Index | Thread Index | Problems or questions? Contact list-master@nist.gov