"""
Copyright 2026, University of Freiburg.

Marc Fuchs <marc.fuchs@cs.uni-freiburg.de>
"""

import random
import time
import matplotlib.pyplot as plt


def selection_sort(array):
    """
    Sort ``array`` in place using selection sort.

    On every pass the smallest element of the still-unsorted tail is found
    and swapped to the front of that tail. When several elements tie for the
    minimum, the first one (lowest index) is chosen.

    Args:
        array (list): The list to sort; it is modified in place.

    Tests:
        >>> array = [12, 11, 13, 5, 6, 7]
        >>> selection_sort(array)
        >>> array
        [5, 6, 7, 11, 12, 13]

        >>> array = [5, 1, 3, 8]
        >>> selection_sort(array)
        >>> array
        [1, 3, 5, 8]

        >>> array = [11, 3, 5, 8]
        >>> selection_sort(array)
        >>> array
        [3, 5, 8, 11]

        >>> array = []
        >>> selection_sort(array)
        >>> array
        []

        >>> array = [12]
        >>> selection_sort(array)
        >>> array
        [12]

        >>> array = [5, 1, 8, 3, 4, 0]
        >>> selection_sort(array)
        >>> array
        [0, 1, 3, 4, 5, 8]

        >>> array = [5, 1, 8, 7, 3, 4, 0]
        >>> selection_sort(array)
        >>> array
        [0, 1, 3, 4, 5, 7, 8]

        >>> array = [2, 2, 2, 2, 1]
        >>> selection_sort(array)
        >>> array
        [1, 2, 2, 2, 2]

        >>> array = [-2, 0, -5, 2, -1]
        >>> selection_sort(array)
        >>> array
        [-5, -2, -1, 0, 2]
    """
    length = len(array)
    for boundary in range(length - 1):
        # Index of the smallest value in array[boundary:]; min() keeps the
        # first of several equal minima, just like a strict '<' scan.
        smallest = min(range(boundary, length), key=array.__getitem__)

        # Move that value to the front of the unsorted tail.
        array[boundary], array[smallest] = array[smallest], array[boundary]


def merge_sort(array):
    """
    Sort ``array`` in place using top-down merge sort.

    One scratch buffer with the same length as ``array`` is allocated here
    and handed to ``mergesort_recursive``, which sorts the whole index range
    [0, len(array)).

    Args:
        array (list): The list to sort; it is modified in place.

    Tests:
        >>> array = [12, 11, 13, 5, 6, 7]
        >>> merge_sort(array)
        >>> array
        [5, 6, 7, 11, 12, 13]

        >>> array = [5, 1, 3, 8]
        >>> merge_sort(array)
        >>> array
        [1, 3, 5, 8]

        >>> array = [11, 3, 5, 8]
        >>> merge_sort(array)
        >>> array
        [3, 5, 8, 11]

        >>> array = []
        >>> merge_sort(array)
        >>> array
        []

        >>> array = [12]
        >>> merge_sort(array)
        >>> array
        [12]

        >>> array = [5, 1, 8, 3, 4, 0]
        >>> merge_sort(array)
        >>> array
        [0, 1, 3, 4, 5, 8]

        >>> array = [5, 1, 8, 7, 3, 4, 0]
        >>> merge_sort(array)
        >>> array
        [0, 1, 3, 4, 5, 7, 8]

        >>> array = [2, 2, 2, 2, 1]
        >>> merge_sort(array)
        >>> array
        [1, 2, 2, 2, 2]

        >>> array = [-2, 0, -5, 2, -1]
        >>> merge_sort(array)
        >>> array
        [-5, -2, -1, 0, 2]
    """
    size = len(array)
    scratch = [None] * size
    mergesort_recursive(array, 0, size, scratch)


def mergesort_recursive(array, start, end, tmp):
    """ Sort the slice array[start .. end-1] in place with merge sort.

        Both halves are sorted recursively. They are then merged into
        tmp[start .. end-1] and copied back into ``array``. Equal elements
        are taken from the left half first, so the merge is stable.

        Args:
            array (list): The underlying list being sorted.
            start (int): First index of the slice to sort.
            end (int): One past the last index of the slice to sort.
            tmp (list): Scratch buffer with at least ``end`` slots; its
                entries tmp[start .. end-1] are overwritten.

        Tests:
        >>> array = [3, 6, 1, 7, 9, 5, 2, 6, 0, 8]
        >>> tmp = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        >>> mergesort_recursive(array, 4, 5, tmp)
        >>> array
        [3, 6, 1, 7, 9, 5, 2, 6, 0, 8]
        >>> mergesort_recursive(array, 3, 8, tmp)
        >>> array
        [3, 6, 1, 2, 5, 6, 7, 9, 0, 8]

    """
    # Zero or one element: nothing to do.
    if end - start <= 1:
        return

    mid = (start + end) // 2
    mergesort_recursive(array, start, mid, tmp)
    mergesort_recursive(array, mid, end, tmp)

    # Merge the sorted runs array[start:mid] and array[mid:end] into tmp.
    left, right = start, mid
    for out in range(start, end):
        # Take from the left run while it has items and its head is not
        # larger than the right head ('<=' keeps the merge stable).
        if left < mid and (right >= end or array[left] <= array[right]):
            tmp[out] = array[left]
            left += 1
        else:
            tmp[out] = array[right]
            right += 1

    array[start:end] = tmp[start:end]


def time_measurement(algorithm, size_range):
    """
    Performs timed tests and returns time measurements.

    For every size ``n`` in ``size_range``, a fresh list of ``n`` random
    integers in [0, 10000] is generated and sorted in place by
    ``algorithm``, and the elapsed time is recorded. After each run the
    list is compared with the expected sorted result. This catches
    algorithms that leave the list out of order and also ones that drop or
    alter elements.

    Args:
        algorithm: The in-place sorting function to be tested,
            e.g. ``selection_sort`` or ``merge_sort``.
        size_range (range): The input sizes to test,
            e.g. ``range(100, 10001, 100)``.

    Returns:
        list: Execution times in seconds, one per size in ``size_range``.

    Raises:
        RuntimeError: If ``algorithm`` did not sort an array correctly.
    """
    results = []
    for n in size_range:
        random_array = [random.randint(0, 10000) for _ in range(n)]
        # Reference result, computed outside the timed region.
        expected = sorted(random_array)

        # perf_counter is monotonic and uses the highest-resolution clock
        # available; time.time() can jump with system clock adjustments.
        start_time = time.perf_counter()
        algorithm(random_array)
        results.append(time.perf_counter() - start_time)

        # Verify that the array was actually sorted (and nothing was lost).
        if random_array != expected:
            raise RuntimeError("Didn't Sort Correctly!")

    return results


def plot_results(x_values, selection_times, merge_times, use_log_scale=False):
    """
    Draw the measured running times of both sorting algorithms in one chart.

    Opens a new 10x6 inch figure with one line per algorithm (selection sort
    in red, merge sort in blue), labels it, and shows it.

    Args:
        x_values (list): The input sizes (n) that were tested.
        selection_times (list): Execution times for Selection Sort.
        merge_times (list): Execution times for Merge Sort.
        use_log_scale (bool): If True, the Y-axis uses a logarithmic scale.
    """
    plt.figure(figsize=(10, 6))

    # (measurements, legend label, line colour) for each algorithm
    series = (
        (selection_times, 'Selection Sort', 'red'),
        (merge_times, 'Merge Sort', 'blue'),
    )
    for times, label, color in series:
        plt.plot(x_values, times, label=label, color=color, linewidth=2)

    # Y-axis scale and the matching axis label / title
    if use_log_scale:
        plt.yscale('log')
        y_label = 'Execution Time (seconds) [Log Scale]'
        title = 'Time Complexity (Logarithmic Scale)'
    else:
        y_label = 'Execution Time (seconds)'
        title = 'Time Complexity'
    plt.ylabel(y_label)
    plt.title(title)

    # Remaining decoration
    plt.xlabel('Input Size (n)')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    size_range = range(100, 10001, 100)

    # (progress message, sorting function) for every algorithm to benchmark
    benchmarks = (
        ("Measuring Selection Sort... (This might take some time)",
         selection_sort),
        ("Measuring Merge Sort...(This might take some time)",
         merge_sort),
    )
    timings = []
    for message, sorter in benchmarks:
        print(message)
        timings.append(time_measurement(sorter, size_range))
        print("Done...")
    selection_times, merge_times = timings

    # The x-axis uses the same sizes n that were measured.
    x_values = list(size_range)

    # Show the results first on a linear, then on a logarithmic Y-axis.
    for log_scale in (False, True):
        plot_results(x_values, selection_times, merge_times,
                     use_log_scale=log_scale)
