#include <stdio.h>
#include <stdlib.h>

// A single element of the binary search tree.
typedef struct node Node;

struct node {
	Node *parent;	// node this one was attached beneath when it was inserted
	Node *left;	// child subtree holding values smaller than this node's value
	Node *right;	// child subtree holding values that are not smaller (equal or larger)

	int value;	// the integer stored in this node
};

// Handle for a whole tree: just a pointer to its top-most node.
struct tree {
	Node *root;	// NULL is treated as "tree is empty" by addNode
};

typedef struct tree Tree;

// Attach _toAdd somewhere beneath the subtree rooted at _current.
//
// Starting at _current, the walk goes left whenever _toAdd's value is smaller
// than the value of the node being visited and right otherwise (so equal
// values end up on the right).  It stops at the first empty child link on
// that path, stores _toAdd there and records the owning node as its parent.
//
// Both arguments must be non-NULL.  _toAdd's own left/right links are left
// exactly as the caller set them.
void findPlace(Node *_toAdd, Node *_current) {
	Node **slot;

	for(;;) {
		// choose the child link on the side this value belongs to
		if(_toAdd->value < _current->value)
			slot = &_current->left;
		else
			slot = &_current->right;

		// an empty link is the insertion point
		if(*slot == NULL)
			break;

		// link already occupied: continue the descent from that child
		_current = *slot;
	}

	_toAdd->parent = _current;
	*slot = _toAdd;
}

// Allocate a node holding _val and insert it into _tree.
//
// The new node becomes the root when the tree is empty; otherwise findPlace
// decides where it goes.  Duplicate values are allowed.
//
// Returns a pointer to the inserted node.  The node is heap-allocated here and
// is released by removeNode.  Returns NULL, leaving the tree untouched, when
// _tree is NULL or the allocation fails.
Node *addNode(int _val, Tree* _tree) {
	if(_tree == NULL)
		return NULL;

	Node *newNode = malloc(sizeof *newNode);
	if(newNode == NULL)
		return NULL;

	// malloc does not clear memory: every link must start out as NULL,
	// otherwise later inserts and traversals would follow garbage pointers
	newNode->parent = NULL;
	newNode->left = NULL;
	newNode->right = NULL;
	newNode->value = _val;

	// if the tree is empty, this is the new root
	if(_tree->root == NULL)
		_tree->root = newNode;

	// otherwise, we need to find the right place to put it
	else
		findPlace(newNode, _tree->root);

	// return reference to the node
	return newNode;
}

// Search helper: look for _val in the subtree rooted at _current.
//
// At each node the search moves left when _val is smaller than that node's
// value and right otherwise, so only one downward path is examined.
// Returns the first node met on that path whose value equals _val, or NULL
// once the path runs out (which includes _current being NULL to begin with).
Node *_findNode(int _val, Node *_current) {
	while(_current != NULL && _current->value != _val) {
		// not a match: step into the side where _val would have been inserted
		_current = (_val < _current->value) ? _current->left : _current->right;
	}

	return _current;
}

// Look _val up in _tree, starting the search at the tree's root.
// Returns the node that _findNode locates, or NULL when the value is not
// found.  _tree itself must not be NULL.
Node *contains(int _val, Tree *_tree) {
	return _findNode(_val, _tree->root);
}

// Unlink _toRemove from _tree, free it, and return the value it held.
//
// Replacement policy:
//   - no right subtree: the left subtree (possibly empty) is promoted into
//     the removed node's position;
//   - otherwise: the left-most node of the right subtree (the in-order
//     successor) takes the removed node's position and adopts both of its
//     subtrees.
// All parent links affected by the move are updated, and _tree->root is
// updated when the removed node was the root.
//
// _toRemove must be a non-NULL node that currently belongs to _tree, and
// _tree must be non-NULL.  The node is freed; the caller must not use the
// pointer afterwards.
int removeNode(Node *_toRemove, Tree *_tree)
{
	// save the int value for later
	int toReturn = _toRemove->value;

	int removingRoot = (_toRemove == _tree->root);
	Node *replacement;

	if(_toRemove->right == NULL)
	{
		// nothing on the right: just promote the left subtree (may be NULL)
		replacement = _toRemove->left;
	}
	else
	{
		// find the left-most node of the removed node's OWN right subtree,
		// tracking its parent on the way down so no stored parent link is
		// needed to splice it out
		Node *successorParent = _toRemove;
		replacement = _toRemove->right;
		while(replacement->left != NULL)
		{
			successorParent = replacement;
			replacement = replacement->left;
		}

		if(successorParent != _toRemove)
		{
			// the successor sits deeper in the right subtree and is a left
			// child: its right subtree takes over the slot it vacates
			successorParent->left = replacement->right;
			if(replacement->right != NULL)
				replacement->right->parent = successorParent;

			// the successor adopts the removed node's right subtree
			replacement->right = _toRemove->right;
			replacement->right->parent = replacement;
		}
		// else: the successor is the removed node's direct right child and
		// simply keeps its own right subtree

		// the successor has no left child by construction, so it can adopt
		// the removed node's left subtree as-is
		replacement->left = _toRemove->left;
		if(replacement->left != NULL)
			replacement->left->parent = replacement;
	}

	// hook the replacement in where the removed node used to hang
	if(removingRoot)
	{
		_tree->root = replacement;
		if(replacement != NULL)
			replacement->parent = NULL;
	}
	else
	{
		Node *parent = _toRemove->parent;

		// If it is a left child
		if(parent->left == _toRemove)
			parent->left = replacement;
		// otherwise it is a right child
		else
			parent->right = replacement;

		if(replacement != NULL)
			replacement->parent = parent;
	}

	// clean up the memory
	free(_toRemove);

	// return the value in the given node
	return toReturn;
}

// Pre-order traversal of the subtree rooted at _root: the node's own value is
// printed first, then its left subtree, then its right subtree.
// Each value is written to stdout followed by a single space.
// A NULL _root (an empty tree) prints nothing.
void preOrder(Node *_root)
{
	// empty subtree: nothing to visit; this also ends the recursion
	if(_root == NULL)
		return;

	printf("%d ", _root->value);
	preOrder(_root->left);
	preOrder(_root->right);
}

// In-order traversal of the subtree rooted at _root: the left subtree is
// printed first, then the node's own value, then the right subtree.
// Each value is written to stdout followed by a single space.
// A NULL _root (an empty tree) prints nothing.
void inOrder(Node *_root)
{
	// empty subtree: nothing to visit; this also ends the recursion
	if(_root == NULL)
		return;

	inOrder(_root->left);
	printf("%d ", _root->value);
	inOrder(_root->right);
}

// Post-order traversal of the subtree rooted at _root: the left subtree is
// printed first, then the right subtree, and the node's own value last.
// Each value is written to stdout followed by a single space.
// A NULL _root (an empty tree) prints nothing.
void postOrder(Node *_root)
{
	// empty subtree: nothing to visit; this also ends the recursion
	if(_root == NULL)
		return;

	postOrder(_root->left);
	postOrder(_root->right);
	printf("%d ", _root->value);
}


// Print the pre-, in- and post-order traversals of _tree, one labelled line
// each.  An empty tree prints only the labels.
static void printTraversals(Tree *_tree)
{
	printf("Pre-Order Traversal: ");
	if(_tree->root != NULL) preOrder(_tree->root);
	printf("\n");

	printf("In-Order Traversal: ");
	if(_tree->root != NULL) inOrder(_tree->root);
	printf("\n");

	printf("Post-Order Traversal: ");
	if(_tree->root != NULL) postOrder(_tree->root);
	printf("\n");
}

// Report on stdout whether _val can currently be found in _tree.
static void reportContains(int _val, Tree *_tree)
{
	if(contains(_val, _tree))
		printf("Tree Contains %d\n", _val);
	else
		printf("Tree does NOT contain %d\n", _val);
}

// Remove one node holding _val from _tree; does nothing when the value is
// not found, so removeNode is never handed a NULL node.
static void removeValue(int _val, Tree *_tree)
{
	Node *found = contains(_val, _tree);
	if(found != NULL)
		removeNode(found, _tree);
}

// Demo driver: builds a small tree, prints its traversals, then removes a few
// values and prints the tree again after each removal.
int main(void) {
	// create a tree to fool around with; the root must start out as NULL so
	// that addNode recognises the tree as empty
	Tree tree = { NULL };
	Tree *myTree = &tree;

	// Add some values to the tree
	addNode(7, myTree);
	addNode(3, myTree);
	addNode(7, myTree);
	addNode(11, myTree);
	addNode(6, myTree);
	addNode(8, myTree);
	addNode(12, myTree);
	addNode(0, myTree);
	addNode(2, myTree);
	addNode(9, myTree);

	/* At this point, the tree looks like this:

	                7
	             /     \
	            3       7
	          /  \        \
	         0    6        11
	          \           /   \
	           2         8     12
	                      \
	                       9
	*/

	// Display contents of the tree
	printTraversals(myTree);

	// Check if certain values are in the tree
	reportContains(4, myTree);
	reportContains(5, myTree);
	reportContains(7, myTree);

	// Remove 7 and check again (there are two 7s in the tree, so one remains)
	removeValue(7, myTree);
	reportContains(7, myTree);

	/* At this point, the tree looks like this:

	                7
	             /     \
	            3       11
	          /  \     /  \
	         0    6   8    12
	          \        \
	           2        9
	*/

	// Display contents of the tree
	printTraversals(myTree);

	// Remove the remaining 7 (the current root)
	removeValue(7, myTree);
	reportContains(7, myTree);

	/* At this point, the tree looks like this:

	                8
	             /    \
	            3      11
	          /  \    /  \
	         0    6  9    12
	          \
	           2
	*/

	// Display contents of the tree
	printTraversals(myTree);

	// Remove a leaf
	removeValue(12, myTree);
	reportContains(12, myTree);

	/* At this point, the tree looks like this:

	                8
	             /    \
	            3      11
	          /  \    /
	         0    6  9
	          \
	           2
	*/

	// Display contents of the tree
	printTraversals(myTree);

	// NOTE: the nodes still in the tree are not freed before exit
	return 0;
}
