"""
RadixSort implementation using BucketSort

Copyright 2024, University of Freiburg.

Philipp Schneider <philipp.schneider@cs.uni-freiburg.de>
Marc Fuchs <marc.fuchs@cs.uni-freiburg.de>
"""

from random import randint
import matplotlib.pyplot as plt
import math
import statistics
import time

import BucketSort


def radix_sort(array, k, base=10):
    '''
    Implements the radix sort algorithm to sort (in place)
        data elements with keys in range(k+1)

    Works least-significant digit first: one BucketSort pass per
    base-`base` digit of the largest key k.

    Args:
        array: array of non-negative integer keys, sorted in place
        k: largest key (must be >= 0)
        base: radix used to split keys into digits (default 10, >= 2)

    Raises:
        ValueError: if k is negative or base is smaller than 2

    >>> array = []
    >>> radix_sort(array, 10)
    >>> str(array)
    '[]'
    >>> array = [0, 0]
    >>> radix_sort(array, 0)
    >>> str(array)
    '[0, 0]'
    >>> array = [12, 8, 1000, 999, 0, 100]
    >>> radix_sort(array, 1000)
    >>> str(array)
    '[0, 8, 12, 100, 999, 1000]'
    >>> array = [10-i for i in range(10)]
    >>> radix_sort(array, 100)
    >>> str(array)
    '[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]'
    >>> array = [100-i for i in range(100)]
    >>> result = [i+1 for i in range(100)]
    >>> radix_sort(array, 100)
    >>> str(array) == str(result)
    True
    >>> array = [5, 3, 7, 0]
    >>> radix_sort(array, 7, base=2)
    >>> str(array)
    '[0, 3, 5, 7]'
    '''
    if k < 0:
        raise ValueError("k must be non-negative")
    if base < 2:
        raise ValueError("base must be at least 2")

    # Count the digits of k in the given base with exact integer arithmetic.
    # math.log(k, base) is inexact (math.log(1000, 10) < 3) and fails for k=0.
    num_digits = 1
    while base ** num_digits <= k:
        num_digits += 1

    # Least significant digit first. LSD radix sort is only correct if each
    # BucketSort pass is stable -- assumed here for BucketSort.bucket_sort.
    place = 1
    for _ in range(num_digits):
        # bind the current place value explicitly instead of closing over a
        # loop variable (avoids late-binding surprises if the key is kept)
        def digit(x, place=place):
            return (x // place) % base

        BucketSort.bucket_sort(array, base, digit)
        place *= base


def is_sorted(lst):
    """
    Check whether a sequence is in non-decreasing order.

    Args:
        lst: an indexable sequence (e.g. a list) of comparable items
    Returns:
        bool: True if every element is <= its successor (vacuously True
              for fewer than two elements), False otherwise.
    """
    for idx in range(1, len(lst)):
        # test `<=` and negate it (rather than testing `>`) so incomparable
        # values such as NaN make the sequence count as "not sorted"
        if not lst[idx - 1] <= lst[idx]:
            return False
    return True


def radix_sort_performance(size=2*10**4, initial_k=2*10**4,
                           k_increment=2*10**4,
                           max_k=(12*(10**5))+1,
                           repetitions=10):
    """
    Benchmark radix_sort for a range of largest keys k.

    For every k in range(initial_k, max_k, k_increment), sorts `repetitions`
    freshly generated arrays of `size` random integers from [0, k], checks
    each result and prints one tab-separated line: size, k, mean time (ms).

    Args:
        size:           number of elements in the arrays
        initial_k:      initial value for the largest value
        k_increment:    increment for the largest value
        max_k:          upperbound (exclusive) for the largest number
        repetitions:    number of repetitions per k (must be >= 1)
    Returns:
        results:        list with the mean running time in ms for each k,
                        in the order of range(initial_k, max_k, k_increment)
    Raises:
        ValueError:     if repetitions is smaller than 1
        Exception:      when radix sort did not sort the array correctly
    """
    if repetitions < 1:
        raise ValueError("repetitions must be at least 1")

    results = []
    for k in range(initial_k, max_k, k_increment):
        # one list per k, so the mean below covers every repetition
        execution_results = []
        for _ in range(repetitions):
            array = [randint(0, k) for _ in range(size)]

            # perf_counter is monotonic and high-resolution; time.time() is
            # neither, which makes it a poor benchmark clock
            start_time = time.perf_counter()
            radix_sort(array, k)
            run_time = (time.perf_counter() - start_time) * 1000

            if not is_sorted(array):
                raise Exception("list not sorted successfully")
            execution_results.append(run_time)
        results.append(statistics.mean(execution_results))
        print("{}\t{}\t{:.1f}".format(size, k, results[-1]))
    return results


if __name__ == "__main__":
    # these values cause a long runtime, but make for nice plots
    n = 10**4
    initial_k = 2*10**4
    k_increment = 2*10**4
    max_k = (12*(10**5))+1
    repetitions = 3

    # benchmark both algorithms with identical parameters
    results_bs = BucketSort.bucket_sort_performance(n, initial_k,
                                                    k_increment,
                                                    max_k, repetitions)
    results_rs = radix_sort_performance(n, initial_k,
                                        k_increment, max_k, repetitions)

    # x-axis: the k values radix_sort_performance iterated over
    # (assumes bucket_sort_performance returns one value per same k)
    scale = list(range(initial_k, max_k, k_increment))
    plt.figure(1)
    plt.xlabel("max key k")
    plt.ylabel("time in ms")
    plt.plot(scale, results_rs, label="radix")
    plt.plot(scale, results_bs, label="bucket")
    plt.legend()

    # overwrites ./plot.pdf on every run; comment out to keep an old plot
    plt.savefig("./plot.pdf")
    plt.show()


# NOTE(review): this is a second __main__ guard. When the file runs as a
# script it executes after the plotting block above (once plt.show()
# returns) and re-runs the radix benchmark with the default parameters
# (size=2*10**4, repetitions=10), which are larger than the plotting run.
# Likely a leftover -- consider removing it or merging it into one guard.
if __name__ == "__main__":
    radix_sort_performance()
