"""
Radix sort and a runtime comparison with BucketSort.bucket_sort (bucket sort using linked lists)

Copyright 2026, University of Freiburg.

Philipp Schneider <philipp.schneider@cs.uni-freiburg.de>
Marc Fuchs <marc.fuchs@cs.uni-freiburg.de>
"""

import math  # noqa
from random import randint
import time
import matplotlib.pyplot as plt  # noqa
import statistics

import BucketSort  # noqa


def radix_sort(array, k):
    '''
    Implements the radix sort algorithm to sort
        data elements with keys in range(k+1)

    The list is sorted in place: the list object passed in is updated and
    None is returned. This is a least-significant-digit radix sort. Each
    pass stably distributes the elements into `base` buckets by one digit
    and concatenates the buckets again. The base is the array length (at
    least 2), so about log_base(k+1) passes of O(n + base) work are needed.

    Args:
        array: list of non-negative integers (each element is its own key)
        k: largest key; values larger than k are still sorted correctly,
           because the largest element is also taken into account

    Raises:
        ValueError: if array contains a negative value

    >>> array = [10-i for i in range(10)]
    >>> radix_sort(array, 100)
    >>> str(array)
    '[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]'
    '''
    if not array:
        return
    # Digit extraction via // and % only orders non-negative integers.
    if min(array) < 0:
        raise ValueError("radix_sort requires non-negative integer keys")
    base = max(2, len(array))
    # Guard against a too-small k leaving the high digits unsorted.
    limit = max(k, max(array))
    exp = 1
    while exp <= limit:
        buckets = [[] for _ in range(base)]
        for value in array:
            buckets[(value // exp) % base].append(value)
        # Slice assignment keeps the caller's list object (in-place contract).
        array[:] = [value for bucket in buckets for value in bucket]
        exp *= base


#################################################################
def is_sorted(lst):
    """
    Check whether a list is in non-decreasing order.

    Args:
        lst: the list to inspect
    Returns:
        bool: True if every element is <= its successor (empty and
        one-element lists count as sorted), False otherwise.
    """
    for idx in range(1, len(lst)):
        # Stop at the first neighbouring pair that is out of order.
        if not lst[idx - 1] <= lst[idx]:
            return False
    return True


def sort_performance(algorithm=radix_sort, size=2*10**4, initial_k=2*10**4,
                     k_increment=2*10**4, max_k=(12*(10**5))+1, repetitions=2):
    """
    Measure the running time of a sorting algorithm for growing key ranges.

    For every largest key k in range(initial_k, max_k, k_increment), the
    function sorts `repetitions` freshly generated arrays of `size` random
    integers drawn from [0, k]. Each result is checked with is_sorted. One
    tab-separated line "size, k, mean time in ms" is printed per k.

    Args:
        algorithm:      sorting function called as algorithm(array, k); it
                        must sort `array` in place (its return value is
                        ignored)
        size:           number of elements in the arrays
        initial_k:      initial value for the largest value
        k_increment:    increment for the largest value
        max_k:          exclusive upper bound for the largest value
        repetitions:    number of repetitions per k (must be >= 1)
    Returns:
        results:        list with the mean running time in ms for each k,
                        in increasing order of k
    Raises:
        ValueError:     if repetitions is smaller than 1
        Exception:      when sorting was not successful
    """
    if repetitions < 1:
        raise ValueError("repetitions must be at least 1")
    results = []
    for k in range(initial_k, max_k, k_increment):
        # One list per k, so the mean covers every repetition, not just the last.
        execution_results = []
        for _ in range(repetitions):
            array = [randint(0, k) for _ in range(size)]

            # perf_counter is monotonic and high-resolution, unlike time.time.
            start_time = time.perf_counter()
            algorithm(array, k)
            run_time = (time.perf_counter() - start_time) * 1000

            if not is_sorted(array):
                raise Exception("list not sorted successfully")
            execution_results.append(run_time)
        results.append(statistics.mean(execution_results))
        print("{}\t{}\t{:.1f}".format(size, k, results[-1]))
    return results


def plot_results(results_bs, results_rs, initial_k, max_k, k_increment):
    """
    Scatter-plot the measured bucket sort and radix sort timings against k.

    Args:
        results_bs:  timings of the bucket sort, one entry per k
        results_rs:  timings of the radix sort, one entry per k
        initial_k, max_k, k_increment:
                     the range(initial_k, max_k, k_increment) of largest
                     keys the timings belong to (x axis)
    """
    ks = list(range(initial_k, max_k, k_increment))
    plt.xlabel("max key k")
    plt.ylabel("time in ms")
    # Radix first, then bucket: this order determines the legend order.
    series = (("radix", "x", results_rs), ("bucket", "o", results_bs))
    for series_label, symbol, timings in series:
        plt.scatter(ks, timings, marker=symbol, s=5, label=series_label)
    plt.legend()

    # To keep a file copy, call plt.savefig("./plot.pdf") before showing.
    plt.show()


if __name__ == "__main__":
    # Benchmark configuration: smooth plots, but a long total runtime.
    n = 10**4
    initial_k = 2 * 10**4
    k_increment = 2 * 10**4
    max_k = 12 * 10**5 + 1
    repetitions = 2
    # Both algorithms are measured over exactly the same sweep of k.
    sweep = dict(size=n, initial_k=initial_k, k_increment=k_increment,
                 max_k=max_k, repetitions=repetitions)

    print("Compute performance of BucketSort...")
    results_bs = sort_performance(BucketSort.bucket_sort, **sweep)

    print("Compute performance of RadixSort...")
    results_rs = sort_performance(radix_sort, **sweep)

    print("Plot...")
    plot_results(results_bs, results_rs, initial_k=initial_k, max_k=max_k,
                 k_increment=k_increment)
    print("Done.")
