/* ==== @(#) COPYRIGHT (C) The Australian National University 1997,1998 ==== */

/* SCCS INFO @(#)Lookahead.c 1.3         last modified P. Strazdins, 98/04/28
 *  first created: P. Strazdins, DCS ANU 98/04
 *
 * any problems, queries or bug reports for this program should be directed to:
 *	peter@cs.anu.edu.au
 */

/* Help text printed by PrintUsage().  Written as adjacent string literals:
 * a raw newline inside a single literal is not valid C and is rejected by
 * current compilers. */
static const char info[] =
  "\n"
  "Lookahead: a program simulating lookahead according to Section 2 of the paper:\n"
  "\tA Comparison of Lookahead and Algorithmic Blocking\n"
  "\tTechniques for Parallel Matrix Factorization\n"
  "Lookahead is over a TRSM-like computation over a 1xQ processor array;\n"
  "\tthe matrix is nQ block columns, with a lookahead factor bounded\n"
  "\tabove by lambda. rho (kappa) is the normalized panel formation\n"
  "\t(communication) cost.\n"
  "the following options for the computation are available:\n"
  "options:\tmeaning:\t\t\tEquations of the paper used:\n"
  "  '-w':\t\tuse wormhole broadcasts\t\t\t(1)-(4)\n"
  "  '-w -h':\tuse wormhole b/c with handshaking\t(1),(2),(4),(8)\n"
  "  '':\t\tuse ring broadcast\t\t\t(1),(3),(4),(11)\n"
  "  '-h':\t\tuse ring b/c with handshaking\t\t(1),(3),(4),(12)\n"
  "If the option '-t' is specified, assertions corresponding to the appropriate\n"
  "selection of Properties (5), (6), (7), (9), (10) and (14) are tested.\n"
  "Non-integral values of lambda, rho and kappa may be used, but these\n"
  "should have exact binary representation.\n"
  "The option '-v v' sets the verbosity level of the output:\n"
  "   v=0:\t\tminimal verbosity (default). print commun stalls for each cell;\n"
  "\t\talso residual stalls in last cell, and total time.\n"
  "\t\tAlso info in each cell on max # of unread messages.\n"
  "   v=1:\t\tfor each cell and each iteration, at each communication point,\n"
  "\t\tprint # of unread messages, and the time between the start of\n"
  "\t\tthe current broadcast and the completion of prev. itn.\n"
  "   v=2:\t\tfor each cell and each iteration, also print tC[]\n"
  "   v=3:\t\tfor each cell and each iteration, also print tM[]\n"
  "   v=4:\t\tfor each cell and each iteration, also print n(i,q) and lambda_i\n"
  "   v=-1:\tprint a time line visualization of the computation,\n"
  "\t\twith time on the vertical axis. If lambda, kappa and rho\n"
  "\t\tare non-integral, setting scale (option '-s scale') so that\n"
  "\t\tscale*lambda, scale*kappa and scale*rho are all integral will\n"
  "\t\tpermit the visualization to exactly represent the computation.\n"
  "Examples (the following were used to draw Fig 7 of the paper):\n"
  "\tLookahead -v -1 -w 1.0 1.0 4 0 4\n"
  "\tLookahead -v -1 -w 1.0 1.0 4 1 4\n"
  "\tLookahead -v -1    1.0 1.0 4 2 4\n";

#ifndef _POSIX_C_SOURCE
#define _POSIX_C_SOURCE 200809L	/* expose getopt(), optarg, optind in <unistd.h> */
#endif

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>


/* Simulation parameters, set from the command line by GetParameters(). */
double lambda = 0.0;	/* upper bound on the lookahead factor */
double rho = 1.0;	/* normalized panel formation cost */
double kappa = 0.0;	/* normalized communication cost */
int n = 1;		/* the matrix has n*Q block columns */
int Q = 1;		/* number of cells in the 1xQ array */
int wormbc = 0;		/* -w: wormhole rather than ring broadcast */
int handshake = 0;	/* -h: handshake before broadcasting */
int verbosity = 0;	/* -v: output detail; -1 selects the time line */
int testthm = 0;	/* -t: check the paper's properties via assert() */
int scale = 1;		/* -s: time-line ticks per unit of time */

/* Write the one-line usage synopsis followed by the full help text (info[])
 * to standard output. */
void PrintUsage(void)
{
  fputs("usage: Lookahead -w -h -v v -s scale -t rho kappa n lambda Q\n", stdout);
  fputs(info, stdout);
  putchar('\n');
}

/* Echo the parameters of this run (the globals filled in by GetParameters())
 * as a single line on standard output. */
void PrintParameters(void)
{
  printf("lookahead 1x%d array wormbc=%d handshake=%d, "
         "rho=%1.3f, kappa=%0.4f, n=%d, lambda=%1.2f\n",
         Q, wormbc, handshake, rho, kappa, n, lambda);
}
	
/* Parse the command line into the global simulation parameters.
 *
 * Options:  -w        wormhole broadcast (default: ring broadcast)
 *           -h        handshaking before each broadcast
 *           -t        test the paper's properties with assert()
 *           -v v      verbosity level (-1 = time-line visualization)
 *           -s scale  time-line ticks per unit of time
 * Operands: rho kappa n lambda Q   (all five required, in this order)
 *
 * A negative lambda is replaced by n.  On an unknown option, a missing
 * option argument or operand, or Q, n or scale below 1, the usage text is
 * printed and the program exits with status 1.
 */
void GetParameters(int argc, char *argv[])
{ extern char *optarg;		/* also declared by <unistd.h> */
  extern int optind;
  int c;

  while ((c = getopt(argc, argv, "twhv:s:")) != -1)
    switch (c) {
      case 'w':
        wormbc = 1;			break;
      case 'h':
        handshake = 1;			break;
      case 't':
        testthm = 1;			break;
      case 'v':
        verbosity = atoi(optarg);	break;
      case 's':
        scale = atoi(optarg);		break;
      default:			/* '?': unknown option or missing argument */
        PrintUsage();
        exit(1);
    }

  /* argv[optind] .. argv[optind+4] must all exist */
  if (optind+4 >= argc) {
    PrintUsage();
    exit(1);
  }
  rho = atof(argv[optind]);
  kappa = atof(argv[optind+1]);
  n = atoi(argv[optind+2]);
  lambda = atof(argv[optind+3]);
  Q = atoi(argv[optind+4]);
  if (lambda < 0) lambda = n;

  /* Q is used as a divisor (i%Q, i/Q) and Q-1 as an array index during the
     simulation; n and scale below 1 make the time-line buffer size
     (proportional to n*scale) zero or negative. */
  if (Q < 1 || n < 1 || scale < 1) {
    fprintf(stderr, "Lookahead: Q, n and scale must all be at least 1\n");
    PrintUsage();
    exit(1);
  }
} /* GetParameters() */

/********************* code for time line visual display ********************/
/* Larger / smaller of two values.  Each argument may be evaluated twice, so
 * do not pass expressions with side effects. */
#define MAX(x, y) (((x) >= (y)) ? (x) : (y))
#define MIN(x, y) (((x) <= (y)) ? (x) : (y))
/* Capacity of the per-cell arrays; Q may not exceed this. */
#define MAXQ 64

/* One tick of one cell's time line, as recorded by record_tv() and shown by
 * print_tv(). */
typedef struct tv {
  char type;	/* activity code: NONE, FACT, COMM, MULT or HS */
  int i;	/* iteration number; negative means print no number */
} tval;

/* Activity codes stored in tval.type and printed verbatim by print_tv(). */
enum {
  NONE = ' ',	/* nothing recorded: idle, or waiting for a message */
  FACT = 'F',	/* panel formation */
  COMM = 'c',	/* communication */
  MULT = 'M',	/* update computation */
  HS   = 'H'	/* handshaking */
};

/* Time-line buffers for the verbosity == -1 visualization: time[q][t] is
 * cell q's activity at tick t.  tmax is the number of ticks allocated per
 * cell; maxtv is one past the last tick holding a non-NONE activity.
 * Declared static: `time' names a standard library function, and giving
 * this array external linkage would use a reserved identifier. */
static tval *time[MAXQ];
static int tmax = 0, maxtv = 0;
  
/* Record activity v of iteration i on cell q's time line over [T0, T1).
 * Times are converted to ticks by multiplying by scale and rounding.
 * Does nothing unless the time-line visualization (verbosity == -1) is on.
 * NONE entries do not extend maxtv, the length print_tv() prints. */
void record_tv(char v, int i, double T0, double T1, int q)
{ int t0, t1, t;

  if (verbosity != -1) return;
  t0 = (int) (scale*T0 + 0.5);
  t1 = (int) (scale*T1 + 0.5);
  /* only time[0..Q-1] are allocated; q == Q would be out of range */
  assert(0 <= q && q < Q && time[q] != NULL);
  /* every tick written must lie inside the tmax-tick buffer */
  assert(0 <= t0 && t0 <= t1 && t1 <= tmax);
  if (v != NONE)
    maxtv = MAX(t1, maxtv);
  for (t = t0; t < t1; t++) {
    time[q][t].type = v;
    time[q][t].i = i;
  }
} /* record_tv() */

void print_tv() 
{ int q, t;
#if 0
  printf("maxtv=%d, tmax=%d\n", maxtv, tmax);
#endif  
  for (t=0; t < maxtv; t++) {
    if (scale==1)	
      printf("%3d | ", t);	
    else  
      printf("%6.3f | ", t/((double) scale));	
    for (q=0; q<Q; q++) {
      if (time[q][t].type!=NONE) 
        if (time[q][t].i>=0)
          printf("%c%-3d| ", time[q][t].type, time[q][t].i);	 	
        else  
          printf("%c   | ", time[q][t].type);	 	
      else
        printf("    | ");
    }    
    printf("\n");
  }  
} /* print_tv() */


/************************  misc functions & main ***************************/

/* n(i,q) = n - floor(i/Q) - [q <= i mod Q]: the count printed as n(i,q) at
 * verbosity 4 and used as cell q's update work in iteration i.  Presumably
 * the block columns cell q still holds after panel i (columns dealt
 * cyclically over the Q cells) -- confirm against the paper. */
int niq(int i, int q)
{
  int owner = i % Q;			/* cell that forms panel i */
  int consumed = i / Q + (q <= owner);
  return n - consumed;
}

/* delta(i, q): 1 when i == q (mod Q), i.e. cell q is the one that forms
 * panel i (qi = i%Q in main()); 0 otherwise. */
#define delta(i, q) ((((i) - (q)) % Q) == 0)

/* Simulate the lookahead computation over the 1xQ array for n*Q iterations
 * (panel i is formed on cell i%Q), optionally checking the paper's
 * properties (-t) and printing per-iteration detail (-v), then print the
 * buffered-message and stall statistics and the total time. */
int main(int argc, char *argv[])
{ int i, q;

  double tM[MAXQ],		/* tM[q]=time cell q completes itn i */
         tC[MAXQ],		/* tC[q]=time cell q begins commun itn i */
         tCs[MAXQ],		/* stores tC[qi] for last Q itns */
         tCp[MAXQ][MAXQ],	/* stores tC[] for last Q itns */
         tS[MAXQ];		/* tS[q]=accum. commun. stalls cell q */
  int jm[MAXQ],			/* At itn. pjm[q], jm[q] is max. # of */
      pjm[MAXQ];		/* unread messages at cell q; djm[q] */
  double djm[MAXQ];		/* is time when last message was read */
  int pjm_[MAXQ];		/* At itn. pjm_[q], djm_[q] is max.   */
  double djm_[MAXQ];		/* diff between time of finish reading*/
				/* msg. of prev itn & current send    */
  GetParameters(argc, argv);
  PrintParameters();
  if (Q > MAXQ) {
    printf("Q=%d too large; redefine constant MAXQ=%d\n", Q, MAXQ);
    exit(1);
  }
  if (Q < 1) {			/* Q is a divisor and Q-1 an index below */
    printf("Q=%d must be at least 1\n", Q);
    exit(1);
  }

  if (verbosity == -1) {
    /* estimated total time, in ticks, sizing each cell's time line */
    tmax = (int) (n*Q*((n-1)*Q/2 + rho + kappa)*scale);
    for (q = 0; q < Q; q++) {
      /* at least one tick, so a zero estimate still yields a valid buffer */
      time[q] = malloc((tmax > 0 ? (size_t) tmax : 1) * sizeof *time[q]);
      if (time[q] == NULL) {
        fprintf(stderr, "Lookahead: cannot allocate time line for cell %d\n", q);
        exit(1);
      }
      record_tv(NONE, 0, 0.0, (double) tmax/scale, q);
    }
  }

  for (q = 0; q < Q; q++) {	/* boundary conditions */
    tM[q] = 0.0, tC[q] = -kappa, tS[q] = 0.0;
    jm[q] = 0, pjm[q] = -1;
    djm[q] = djm_[q] = -1.0, pjm_[q] = -1;
  }
  for (i = 0; i < Q; i++)
    for (q = 0; q < Q; q++)
      tCp[i][q] = 0.0;

  for (i = 0; i < Q*n; i++) {
    int dq, qi = i%Q;		/* qi: the cell forming panel i */
    tCp[qi][i%Q] = tC[qi];
    tC[qi] = tCs[i%Q] = tM[qi] + rho;
    record_tv(FACT, i, tM[qi], tC[qi], qi);

    if (handshake) {		/* note: tC[] has values of prev itn. */
      double th = tC[qi];
      if (wormbc) {
        for (dq = 1; dq < Q; dq++)
          th = MAX(th, tC[(qi+dq)%Q]+kappa);
      } else {
        th = MAX(th, tC[(qi+1)%Q]+kappa);
        if (testthm)
          /* for case of ring b/c: handshaking not needed at source */
          assert(th == tC[qi]);
      }
      record_tv(HS, -1, tC[qi], th, qi);
      tS[qi] += th-tC[qi];
      tCs[i%Q] = tC[qi] = MAX(tC[qi], th);
    }

    for (dq = 1; dq < Q; dq++) {
      q = (qi+dq)%Q;
      tCp[q][i%Q] = tC[q];
      if (wormbc) {
        tC[q] = MAX(tM[q], tCs[i%Q]);
        tS[q] += tC[q] - tM[q];
        record_tv(NONE, i, tM[q], tC[q], q);

        if (testthm && lambda <= rho+1)
          /* Property 7: handshaking will not be needed if lookahead is small*/
          assert(tCp[q][i%Q] + kappa <= tCs[i%Q]);

      } else {
        tC[q] = MAX(tM[q], tC[(Q+q-1)%Q]+kappa);
        record_tv(NONE, i, tM[q], tC[q], q);
        tS[q] += tC[q] - tM[q];

        if (handshake && dq < Q-1) {
          double th = MAX(tC[q], tC[(q+1)%Q]+kappa);
          record_tv(HS, -1, tC[q], th, q);
          if (testthm)
            /* for ring b/c: handshaking not needed at intermediate cells */
            assert(th == tC[q]);
          /* accumulate the handshake stall BEFORE advancing tC[q] to th,
             as for the source cell above (the reverse order always added 0) */
          tS[q] += th-tC[q];
          tC[q] = th;
        }
      }
    } /* for (dq...) */

    if (testthm && i>0 && qi!=Q-1) {
      int niqm1 = niq(i, qi)-1;	/* test if sufficient lookahead achieved */
      double lambda_i = MAX(0.0, MIN(lambda, niqm1));
      if (lambda_i>=rho/*+(!wormbc? MAX(kappa-1.0, 0.0): 0.0)*/) {
        if (wormbc) {
          /* Property 6: no stalls at cell Q-1 */
          assert(tM[Q-1]>=tCs[i%Q]);
        } else {
          double TCs = tC[Q-2] - /*pipeline bubble*/(qi==0? Q*kappa -
                       MIN(Q*kappa, MIN(lambda_i, rho+1)) : 0.0);
          /* Properties 6 & 13: no stalls at cell Q-1, except for pipeline bubble
             where qi=0 */
          assert(tM[Q-1]>=TCs);
        }
      }
    }

    if (testthm && wormbc && handshake) {
      double nohs_tCs = tM[qi] + rho;
      double nohs_Cq = MAX(tM[Q-1], nohs_tCs);
      assert(nohs_Cq+kappa >= tCs[i%Q]);
      /* Property 10: handshaking didnt add stall at cell Q-1*/
      if (qi==Q-1)
        /* Property 9: handshaking did not delay send at cell Q-1 */
        assert(tCs[i%Q] == nohs_tCs);
    }

    if (verbosity>=1)
      printf("%3d: ", i);
    for (q=0; q < Q; q++) {
      int j=0;
      double tC_snd_i = wormbc? tCs[i%Q]: tC[(Q+q-1)%Q];
      if (q!=qi)
        while (j<=i && (tC_snd_i <= tCp[q][(Q+i-j)%Q]+kappa) && j < Q)
          j++;
      if (q!=qi && tCp[q][i%Q]+kappa-tC_snd_i > djm_[q])
        djm_[q] = tCp[q][i%Q]+kappa-tC_snd_i, pjm_[q]=i;
      if (j > jm[q] && q!=qi && i>0) {
        djm[q] = tCp[q][(Q+i-j+1)%Q]+kappa - tC_snd_i;
        jm[q] = j, pjm[q] = i;
      }
      if (verbosity>=1) {
        if (q==qi)
          printf("** ");
        else
          printf("%2d ", j);
        printf("%6.2f, ", tM[q]-tC_snd_i);
      }
    } /* for (q...) */
    if (verbosity>=1)
      printf("\n");

    if (verbosity>=4)
      printf("(nl:)");
    for (q=0; q < Q; q++) {
      int nq = niq(i, q);
      double lambda_i = MAX(0.0, MIN(nq-1.0, lambda));
      double lambda_im1 = (i==0 && q==0)? 0.0:
                          MAX(0.0, MIN(niq(i-1, q)-1.0, lambda));
      double tCK = tC[q] + kappa;
      record_tv(COMM, i, tC[q], tCK, q);
      tM[q] = tCK + (nq - delta(i+1, q)*lambda_i) + delta(i,q)*lambda_im1;
      record_tv(MULT, i-1, tCK, tCK+delta(i,q)*lambda_im1, q);
      record_tv(MULT, i, tCK+delta(i,q)*lambda_im1, tM[q], q);

      if (verbosity>=4) {
        printf("%2d ", nq);
        if (delta(i+1, q))
          printf("%4.1f ", lambda_i);
        else
          printf(".... ");
        if (delta(i, q))
          printf("%d, ", (int) lambda_im1);
        else
          printf(" , ");
      }
    } /* for (q...) */

    for (q=MAX(0, i-(n-1)*Q); q < (Q-1)*testthm; q++) {
      /* take into account extra lags from pipelined communications
         where cell q is downstream of cell Q-1 */
      double ringbc_lag = (wormbc && q<qi? 0.0: (q+1)*kappa);
      if (/*wormbc &&*/ qi==Q-2) {
        double lambda_i = MAX(0.0, MIN(niq(i, Q-1)-1.0, lambda));
        /* Property 5: cell Q-1 is on the critical path */
        assert(tM[q]<=tM[Q-1]+lambda_i+rho-1.0 + ringbc_lag);
      } else /*if (wormbc)*/ {
        /* Property 5: cell Q-1 is on the critical path */
        assert(tM[q]<=tM[Q-1]+rho-1.0 + ringbc_lag);
      }
    }

    if (verbosity>=4)
      printf("\n");
    if (verbosity>=2) {
      printf("(tC:)");
      for (q=0; q < Q; q++)
        printf("%9.2f, ", tC[q]);
      printf("\n");
    }
    if (verbosity>=3) {
      printf("(tM:)");
      for (q=0; q < Q; q++)
        printf("%9.2f, ", tM[q]);
      printf("\n");
    }
  } /* for (i..) */

  print_tv();

  printf("\n");
  printf("max buffered msgs (# msgs, itn #, gap till next msg read)\n");
  for (q = 0; q < Q; q++)
    printf(" %d %2d %6.2f,", jm[q]+1, pjm[q], djm[q]);
  printf("\n");
  printf("msg arrival (itn #, max gap until read of prev. itn)\n");
  for (q = 0; q < Q; q++)
    printf("   %2d %6.2f,", pjm_[q], djm_[q]);
  printf("\n");
  printf("accumulated stalls on msg receipts\n");
  for (q = 0; q < Q; q++)
    printf("      %6.2f,", tS[q]);
  printf("\n");
  printf("total time=%d, residual stalls at cell %d: %G\n", (int) tM[Q-1], Q-1,
         tS[Q-1] - (!wormbc? Q*n*kappa: 0.0));

  return 0;
} /* main() */


