/* bstree - test code for binary search trees
 *
 * THIS IS A DRAFT VERSION!  ***************<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
 *
 * Brendan McKay, September 2007
 *
 * Usage:  bstree keys mixes replications
 *
 *    keys  =  number of different keys
 *    mixes =  amount of randomisation of the keys
 *    replications = number of tests
 *
 *  The keys 0..<keys>-1 in order are first mixed by swapping
 *  <mixes> pairs of values, chosen at random.  Then for each key
 *   (1) insert the key into a binary search tree
 *   (2) search for a key chosen randomly from the keys inserted so far
 *       (including the one just inserted)
 *   (3) after the 1st, 3rd, 5th, ... insertions,
 *       delete a key chosen randomly from the keys inserted so far
 *  (In steps (2) and (3), the key might have been deleted already.)
 *
 *  This process is repeated <replications> times.  Then a measure of the
 *  average cost of insertions, lookups and deletions is written.
 */

#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>


/* A tree node: one key plus links to its two subtrees.  A link is NULL
 * where the corresponding subtree is empty. */
typedef struct nodestruct node;

struct nodestruct
{
   node *left;      /* subtree holding the smaller keys */
   node *right;     /* subtree holding the larger keys */
   int key;
};

/* Boolean type.
 * Since C23, bool, true and false are keywords, so an enumeration that
 * declares constants named false and true no longer compiles there (and it
 * collides with <stdbool.h>'s macros if that header is ever included).
 * Use <stdbool.h> on any C99-or-later compiler; keep the original
 * enumeration only for C89/C90 compilers, which lack that header. */
#if defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L
#include <stdbool.h>
typedef bool boolean;
#else
typedef enum { false=0, true=1 } boolean;
#endif

static unsigned long counter;   /* Work counter bumped by the tree operations
                                 * (once per call plus once per key
                                 * comparison); reset to 0 by main() */

#define MAXKEYS 10000           /* Largest number of keys main() accepts */
static int keylist[MAXKEYS];    /* Keys in the order they are inserted */

/************************************************************************/

node*
newnode(void)
/* Return a freshly malloc'ed node with its left and right links set to
 * NULL; the key field is left for the caller to fill in.  If the
 * allocation fails, an error is written to stderr and the program exits
 * with status 1, so the result is never NULL.  Release it with free( ).
 */
{
    node *fresh = malloc(sizeof *fresh);

    if (fresh == NULL)
    {
        fprintf(stderr,"Error: malloc failed in newnode()\n");
        exit(1);
    }

    fresh->right = NULL;
    fresh->left = NULL;

    return fresh;
}

/************************************************************************/

void
printtree(node *root, int level)
/* Print the tree rooted at root sideways on stdout.  The walk is in-order
 * (left subtree, node, right subtree) with one key per line.  Each line
 * is preceded by one "-  " marker per level of depth, and root itself
 * sits at depth `level`.  An empty tree prints nothing. */
{
    int depth;

    if (root == NULL)
        return;

    printtree(root->left, level + 1);

    for (depth = level; depth > 0; --depth)
        printf("-  ");
    printf("%d\n", root->key);

    printtree(root->right, level + 1);
}

/************************************************************************/

boolean
checknode(node *ptr)
/* Helper for checktree(); ptr must not be NULL.
 * For each child of *ptr that exists, complain on stderr if it is on the
 * wrong side of *ptr, then check that child's subtree recursively.  A
 * left child must hold a smaller key and a right child a larger one.
 * Checking does not stop at the first problem: every violation is
 * reported.  Returns true when nothing was reported. */
{
    boolean leftok = true;
    boolean rightok = true;
    node *kid;

    kid = ptr->left;
    if (kid != NULL)
    {
        if (!(kid->key < ptr->key))
        {
            fprintf(stderr,
               "Error: the key %d is in the left child of the key %d\n",
               kid->key,ptr->key);
            leftok = false;
        }
     /* parent:
        if (kid->parent != ptr)
        {
            fprintf(stderr,
                "Error: the node with key %d has an invalid parent pointer\n",
                kid->key);
            leftok = false;
        }
      */
        /* Recurse first so the whole subtree is always examined. */
        leftok = checknode(kid) && leftok;
    }

    kid = ptr->right;
    if (kid != NULL)
    {
        if (!(kid->key > ptr->key))
        {
            fprintf(stderr,
               "Error: the key %d is in the right child of the key %d\n",
               kid->key,ptr->key);
            rightok = false;
        }
     /* parent:
        if (kid->parent != ptr)
        {
            fprintf(stderr,
                "Error: the node with key %d has an invalid parent pointer\n",
                kid->key);
            rightok = false;
        }
      */
        rightok = checknode(kid) && rightok;
    }

    return leftok && rightok;
}

static boolean
checkorder(node *ptr, boolean *havelast, int *last)
/* Auxiliary procedure used by checktree().  Walk the subtree at ptr in
 * order (left subtree, node, right subtree) and check that the keys met
 * are strictly increasing.  *havelast says whether any key has been seen
 * yet and *last holds the most recent one; both are updated as the walk
 * proceeds.  Each out-of-order pair is reported on stderr.  Returns true
 * if every key was in order. */
{
    boolean ok;

    if (ptr == NULL) return true;

    ok = checkorder(ptr->left,havelast,last);

    if (*havelast && ptr->key <= *last)
    {
        fprintf(stderr,
           "Error: in-order traversal has the key %d after the key %d\n",
           ptr->key,*last);
        ok = false;
    }
    *havelast = true;
    *last = ptr->key;

    if (!checkorder(ptr->right,havelast,last)) ok = false;

    return ok;
}

boolean
checktree(node *root)
/* Check the validity of the tree with the given root.  Problems are
 * reported on stderr; returns true if none were found.
 *
 * checknode() only compares each node with its own children.  That misses
 * a key placed on the wrong side of a more distant ancestor, e.g. 12 as
 * the right child of 5 inside the left subtree of 10.  A binary search
 * tree with distinct keys is valid exactly when its in-order traversal is
 * strictly increasing, so that is checked as well. */
{
    boolean ok, havelast;
    int last;

    if (root == NULL) return true;

    ok = checknode(root);

    /* Only look for ancestor-level violations once the local checks pass,
     * so that a single misplaced child is not reported twice. */
    if (ok)
    {
        havelast = false;
        last = 0;
        ok = checkorder(root,&havelast,&last);
    }

 /* parent:
    if (parent(root) != NULL)
    {
        fprintf(stderr,"Error: root should have NULL parent pointer\n");
        ok = false;
    }
  */

    return ok;
}

/************************************************************************/

void
newtree(node **root)
/* Make *root an empty tree.  An empty tree is represented by a NULL root
 * pointer, so this simply stores NULL through the given address. */
{
    node *const empty = NULL;

    *root = empty;
}

/************************************************************************/

void
freetree(node *root)
/* Release every node of the tree rooted at root, visiting both subtrees
 * before the node itself.  An empty tree (NULL) is a no-op. */
{
    node *lsub, *rsub;

    if (root == NULL)
        return;

    lsub = root->left;
    rsub = root->right;
    freetree(lsub);
    freetree(rsub);
    free(root);
}

/************************************************************************/

boolean
ispresent(node **rootptr, int key)
/* Look for key in the tree whose root is *rootptr; the tree is only read.
 * The global counter is bumped once for the call and once for every node
 * examined.  Returns true if key is found, false otherwise. */
{
    node *cur;

    ++counter;

    for (cur = *rootptr; cur != NULL;
         cur = (key < cur->key) ? cur->left : cur->right)
    {
        ++counter;
        if (cur->key == key)
            return true;
    }

    return false;
}

/************************************************************************/

void
insertkey(node **rootptr, int key)
/* Add key to the tree whose root is *rootptr; nothing happens if key is
 * already present.  *rootptr is updated when the tree was empty.  The
 * global counter is bumped once for the call and once for every existing
 * node examined on the way down.
 */
{
    node **link;
    node *cur,*fresh;

    ++counter;

    /* Walk down while remembering the address of the pointer we followed.
     * When that pointer is NULL, it is exactly where the new node
     * belongs: the root pointer itself or a parent's left/right field. */
    link = rootptr;
    while ((cur = *link) != NULL)
    {
        ++counter;
        if (cur->key == key)        /* Key already present */
            return;
        link = (cur->key > key) ? &cur->left : &cur->right;
    }

    fresh = newnode();
    fresh->key = key;
    *link = fresh;
}

/************************************************************************/

void
deletekey(node **rootptr, int key)
/* Remove key from the tree whose root is *rootptr, if it is present.
 * *rootptr changes if the root node itself is unlinked.  The global
 * counter is bumped once for the call and once per node examined in the
 * search.  It is also bumped once per step down the left spine while
 * looking for an in-order successor.
 *
 * Technique: keep `link`, the address of the pointer that refers to the
 * current node (initially rootptr itself).  Unlinking a node is then a
 * single store through link, whether that pointer is the root pointer or
 * a left/right field of the parent. */
{
    node **link;
    node *victim,*holder;

    ++counter;

    /* Search.  On exit victim is the node holding key (NULL if absent)
     * and *link == victim. */
    link = rootptr;
    while ((victim = *link) != NULL)
    {
        ++counter;
        if (victim->key == key)
            break;
        link = (victim->key > key) ? &victim->left : &victim->right;
    }

    if (victim == NULL) return;     /* Key not present */

    /* Two children: copy up the key of the in-order successor, which is
     * the leftmost node of the right subtree.  Then delete that node
     * instead; it has no left child. */
    if (victim->left != NULL && victim->right != NULL)
    {
        holder = victim;
        link = &holder->right;
        victim = *link;
        while (victim->left != NULL)
        {
            ++counter;
            link = &victim->left;
            victim = *link;
        }
        holder->key = victim->key;
    }

    /* victim now has at most one child: splice it out of the tree. */
    *link = (victim->left != NULL) ? victim->left : victim->right;
    free(victim);
}

/************************************************************************/

static int
parseint(const char *s, int *value)
/* Convert the whole of the string s to an int and store it in *value.
 * sscanf("%d") has undefined behaviour on overflow and ignores trailing
 * junk.  This function instead rejects empty strings, trailing characters
 * and out-of-range values.  Returns 1 on success, else 0 (in which case
 * *value is left unchanged). */
{
    char *end;
    long v;

    errno = 0;
    v = strtol(s,&end,10);
    if (end == s || *end != '\0' || errno == ERANGE
                 || v < INT_MIN || v > INT_MAX)
        return 0;

    *value = (int)v;
    return 1;
}

int
main(int argc, char *argv[])
/* Run the experiment described at the top of the file and print the
 * average work (as measured by the global counter) per insertion,
 * lookup and deletion.  Exits with status 1 on bad arguments or if a
 * tree fails checktree(). */
{
    int keys,mixings,replications;
    int repl,i,j1,j2,temp;
    node *root;
    /* Totals are kept in double.  With little mixing the tree degenerates
     * and one replication can cost on the order of keys*keys comparisons,
     * which soon overflows a 32-bit unsigned long.  A double counts
     * exactly up to 2^53. */
    double lookupcount,insertcount,deletecount;
    double lookupav,insertav,deleteav;
    double operations;

    srandom((unsigned)time(NULL));   /* Initialise random number generator */

    if (argc != 4 ||
        !parseint(argv[1],&keys) ||
        !parseint(argv[2],&mixings) ||
        !parseint(argv[3],&replications))
    {
        fprintf(stderr,"Usage: bstree keys mixings replications\n");
        exit(1);
    }

    if (keys < 1 || keys > MAXKEYS)
    {
        fprintf(stderr,"Error: must have 1..%d keys\n",MAXKEYS);
        exit(1);
    }

    if (replications <= 0)
    {
        fprintf(stderr,"Error: must have at least one replication\n");
        exit(1);
    }

    lookupcount = insertcount = deletecount = 0.0;

    for (repl = 0; repl < replications; ++repl)
    {
        /* Start from the keys 0..keys-1 in order, then apply <mixings>
         * random swaps (zero or a negative value means no swaps). */
        for (i = 0; i < keys; ++i)
            keylist[i] = i;
        for (i = 0; i < mixings; ++i)
        {
            j1 = random() % keys;
            j2 = random() % keys;
            temp = keylist[j1];
            keylist[j1] = keylist[j2];
            keylist[j2] = temp;
        }

        newtree(&root);
        for (i = 0; i < keys; ++i)
        {
            counter = 0;
            insertkey(&root,keylist[i]);
            insertcount += counter;
            counter = 0;
            ispresent(&root,keylist[random()%(i+1)]);
            lookupcount += counter;
            counter = 0;
            if (i % 2 == 0)  /* After the 1st, 3rd, 5th, ... insertion */
            {
                deletekey(&root,keylist[random()%(i+1)]);
                deletecount += counter;
            }
        }

        /* checktree() reports the details on stderr.  Don't go on to
         * print averages measured on a corrupt tree. */
        if (!checktree(root))
        {
            fprintf(stderr,"Error: invalid tree in replication %d\n",repl+1);
            freetree(root);
            exit(1);
        }
        freetree(root);
    }

    /* Convert before multiplying: keys*replications can overflow int. */
    operations = (double)keys * replications;
    insertav = insertcount / operations;
    lookupav = lookupcount / operations;
    /* i % 2 == 0 holds for (keys+1)/2 of the values 0..keys-1. */
    deleteav = deletecount / ((double)((keys+1)/2) * replications);

    printf("Average work per insertion = %8.4f\n",insertav);
    printf("Average work per lookup    = %8.4f\n",lookupav);
    printf("Average work per deletion  = %8.4f\n",deleteav);

    return 0;
}
