#include <assert.h>
#include <errno.h>
#include <limits.h>
#include <math.h>
#include <stdio.h>
#include <stdlib.h>
double distance(double x1, double y1, double x2, double y2);
/*
 * Minimum-image displacement along one axis of a periodic box.
 *
 * delta is the raw separation (coordinate i minus coordinate j) and rcut is
 * half the box side, so the box period is 2*rcut.  The result is delta
 * shifted by at most one period so that it lies in [-rcut, rcut].
 *
 * Only valid for |delta| < 2*rcut, which holds whenever both coordinates
 * lie inside the primary box [0, 2*rcut).
 */
static double min_image(double delta, double rcut)
{
  if (delta > 0.0) {
    return delta - 2*rcut*floor(delta/rcut);
  }
  return delta + 2*rcut*floor(-delta/rcut);
}

/*
 * Self-test for the minimum-image distance formula in a 2d periodic box.
 *
 * Usage: prog rside natom
 *   rside  length of the side of the 2d box (must be finite and > 0)
 *   natom  number of atoms placed in the box (must be >= 0)
 *
 * Random coordinates are generated with a fixed seed, then for every pair of
 * atoms the closed-form minimum-image distance is compared against the
 * smallest of the 9 explicitly constructed image distances.
 *
 * Returns 0 when every pair agrees, -1 on bad command-line arguments, and
 * EXIT_FAILURE on allocation failure or when the two distances disagree.
 */
int main(int argc, char *argv[])
{
  struct coord {
    double xx,yy;
  };
  struct coord *pcoord = NULL;
  double rside,rcut;
  double xi,yi,xj,yj;
  double d[9],ddd,dmin,dxx,dyy;
  long natom_arg;
  int natom;
  int status = EXIT_SUCCESS;
  char *end;
  int i,j,k;

  if (argc != 3) {
    printf(" %s rside natom \n",argv[0]);
    return -1;
  }

  // rside is the length of the side of the 2d box; it must be a positive,
  // finite number or the cutoff below is meaningless
  errno = 0;
  rside = strtod(argv[1], &end);
  if (end == argv[1] || *end != '\0' || errno == ERANGE ||
      !isfinite(rside) || rside <= 0.0) {
    fprintf(stderr,"%s: rside must be a positive number, got '%s'\n",
            argv[0],argv[1]);
    return -1;
  }

  // natom is the number of atoms in the box; atoi() cannot report errors,
  // so parse with strtol and range-check before narrowing to int
  errno = 0;
  natom_arg = strtol(argv[2], &end, 10);
  if (end == argv[2] || *end != '\0' || errno == ERANGE ||
      natom_arg < 0 || natom_arg > INT_MAX) {
    fprintf(stderr,"%s: natom must be a non-negative integer, got '%s'\n",
            argv[0],argv[2]);
    return -1;
  }
  natom = (int) natom_arg;

  // our cutoff is 1/2 the box size (has to be this or less)
  rcut = rside/2.0;

  // calloc checks natom*sizeof for overflow; it may legitimately return
  // NULL for a zero-element request, which is not an error here
  pcoord = calloc((size_t) natom, sizeof *pcoord);
  if (pcoord == NULL && natom > 0) {
    fprintf(stderr,"%s: cannot allocate %d coordinates\n",argv[0],natom);
    return EXIT_FAILURE;
  }

  // set up some random coordinates in [0, rside); the seed is fixed so
  // every run sees the same configuration.
  // srand48/drand48 are POSIX (XSI) functions, not ISO C.
  srand48(123456789);
  for (i=0;i<natom;i++){
    pcoord[i].xx=rside*drand48();
    pcoord[i].yy=rside*drand48();
    printf("coord %lf %lf\n",pcoord[i].xx,pcoord[i].yy);
  }

  // loop over atom-atom interactions (each unordered pair once)
  for (i=1;i<natom;i++){
    xi=pcoord[i].xx;
    yi=pcoord[i].yy;
    for (j=0;j<i;j++){
      xj=pcoord[j].xx;
      yj=pcoord[j].yy;

      // shortest distance between i and any of the 9 j images,
      // obtained by wrapping each component independently
      dxx = min_image(xi-xj, rcut);
      dyy = min_image(yi-yj, rcut);
      ddd = sqrt(dxx*dxx+dyy*dyy);

      // Check this by explicitly considering images
      d[0]=distance(xi,yi,xj,yj);
      d[1]=distance(xi,yi,xj+rside,yj      );
      d[2]=distance(xi,yi,xj      ,yj+rside);
      d[3]=distance(xi,yi,xj+rside,yj+rside);
      d[4]=distance(xi,yi,xj-rside,yj      );
      d[5]=distance(xi,yi,xj      ,yj-rside);
      d[6]=distance(xi,yi,xj-rside,yj-rside);
      d[7]=distance(xi,yi,xj+rside,yj-rside);
      d[8]=distance(xi,yi,xj-rside,yj+rside);

      // which is the smallest in d[]
      dmin=d[0];
      for (k=0;k<9;k++){
        printf("%lf ",d[k]);
        if (d[k] < dmin) {
          dmin = d[k];
        }
      }
      printf("min_d %lf ddd %lf\n",dmin,ddd);

      // stop if we got it wrong.  The relative tolerance is written as a
      // product so coincident atoms (ddd == 0) do not divide by zero, and
      // the check is an explicit test rather than assert() so it cannot be
      // compiled out with NDEBUG.
      if (fabs(dmin-ddd) > 1.0e-10*ddd) {
        fprintf(stderr,"ARGH ERROR!\n");
        status = EXIT_FAILURE;
        goto cleanup;
      }
    }
  }

cleanup:
  free(pcoord);
  return status;
}
double distance(double x1, double y1, double x2, double y2){
  double dist;
  dist = sqrt((x1-x2)*(x1-x2)+(y1-y2)*(y1-y2));
  return dist;
}

