import random
import time
import matplotlib.pyplot as plt


def selection_sort(array):
    """
    Sort ``array`` in place using selection sort.

    On each pass the smallest element of the still-unsorted suffix is
    swapped into the first unsorted position. This always performs
    O(n^2) comparisons, whatever the input order. It is not stable:
    the swap can move an element past an equal one.

    Args:
        array (list): The array to be sorted. It is modified in place.

    Returns:
        None

    Tests:
        >>> array = [12, 11, 13, 5, 6, 7]
        >>> selection_sort(array)
        >>> array
        [5, 6, 7, 11, 12, 13]
        >>> empty = []
        >>> selection_sort(empty)
        >>> empty
        []
    """
    n = len(array)
    # Once n - 1 positions are fixed, the last element is already the largest.
    for i in range(n - 1):
        # Index of the smallest element in array[i:]; min() returns the
        # first such index when there are ties.
        smallest = min(range(i, n), key=array.__getitem__)
        if smallest != i:
            array[i], array[smallest] = array[smallest], array[i]


def merge_sort(array):
    """
    Sort ``array`` in place using merge sort.

    Allocates one scratch buffer the size of ``array`` and hands it to
    :func:`mergesort_recursive`, which reuses it at every merge step.

    Args:
        array (list): The array to be sorted. It is modified in place.

    Returns:
        None

    Tests:
        >>> array = [12, 11, 13, 5, 6, 7]
        >>> merge_sort(array)
        >>> array
        [5, 6, 7, 11, 12, 13]
        >>> empty = []
        >>> merge_sort(empty)
        >>> empty
        []
    """
    n = len(array)
    # A single shared buffer avoids allocating a new list at each merge.
    tmp = [None] * n
    mergesort_recursive(array, 0, n, tmp)


def mergesort_recursive(array, start, end, tmp):
    """
    Sort the subarray array[start .. end-1] in place using merge sort.

    Elements outside [start, end) are never read or written.

    Args:
        array (list): The underlying array.
        start (int): Starting index of the subarray to be sorted.
        end (int): Ending index (exclusive) of the subarray to be sorted.
        tmp (list): Scratch buffer indexed in parallel with ``array``.
            It must satisfy ``len(tmp) >= end``. Only tmp[start:end] is
            overwritten.

    Tests:
        >>> array = [3, 6, 1, 7, 9, 5, 2, 6, 0, 8]
        >>> tmp = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
        >>> mergesort_recursive(array, 4, 5, tmp)
        >>> array
        [3, 6, 1, 7, 9, 5, 2, 6, 0, 8]
        >>> mergesort_recursive(array, 3, 8, tmp)
        >>> array
        [3, 6, 1, 2, 5, 6, 7, 9, 0, 8]
    """
    # Ranges of length 0 or 1 are already sorted.
    if end - start <= 1:
        return
    mid = (start + end) // 2
    mergesort_recursive(array, start, mid, tmp)
    mergesort_recursive(array, mid, end, tmp)
    _merge(array, start, mid, end, tmp)


def _merge(array, start, mid, end, tmp):
    """Merge sorted runs array[start:mid] and array[mid:end] via tmp[start:end]."""
    left, right, out = start, mid, start
    while left < mid and right < end:
        # '<=' takes from the left run on ties, which keeps the sort stable.
        if array[left] <= array[right]:
            tmp[out] = array[left]
            left += 1
        else:
            tmp[out] = array[right]
            right += 1
        out += 1
    # At most one run still has elements left; copy its remainder.
    while left < mid:
        tmp[out] = array[left]
        left += 1
        out += 1
    while right < end:
        tmp[out] = array[right]
        right += 1
        out += 1
    array[start:end] = tmp[start:end]

###############################################################################
###############################################################################
###############################################################################


def time_measurement(algorithm, size_range, max_value=10000):
    """
    Performs timed tests and returns time measurements.

    For every input size n in ``size_range``, the function:

    1. generates a fresh list of n random integers in ``[0, max_value]``;
    2. times one call of ``algorithm`` on it with :func:`time.perf_counter`;
    3. checks that the list ended up in ascending order.

    The check runs outside the timed region.

    Args:
        algorithm: The sorting algorithm to be tested. It must sort the
            list it receives in place; its return value is ignored.
        size_range (range): Specifies the input sizes for the measurements.
        max_value (int): Inclusive upper bound for the random values.
            Defaults to 10000.

    Returns:
        list: Execution times in seconds, one per size, in the order of
        ``size_range``.

    Raises:
        RuntimeError: If ``algorithm`` leaves a list out of order.
    """
    results = []
    for n in size_range:
        random_array = [random.randint(0, max_value) for _ in range(n)]
        # perf_counter is a monotonic, high-resolution clock intended for
        # short durations. time.time() follows the system clock, which can
        # be adjusted mid-measurement.
        start_time = time.perf_counter()
        algorithm(random_array)
        results.append(time.perf_counter() - start_time)

        # Verify that the array was actually sorted. The generator stops
        # at the first out-of-order pair instead of building a full list.
        if not all(a <= b for a, b in zip(random_array, random_array[1:])):
            raise RuntimeError("Didn't Sort Correctly!")

    return results


def plot_results(x_values, selection_times, merge_times, use_log_scale=False):
    """
    Draw a line chart comparing the running times of the two sorting algorithms.

    Args:
        x_values (list): Input sizes (n), used for the horizontal axis.
        selection_times (list): Measured Selection Sort times, one per size.
        merge_times (list): Measured Merge Sort times, one per size.
        use_log_scale (bool): When True, the Y-axis is drawn on a
            logarithmic scale and the labels say so.
    """
    plt.figure(figsize=(10, 6))

    # One line per algorithm: (timings, legend label, line colour).
    series = (
        (selection_times, 'Selection Sort', 'red'),
        (merge_times, 'Merge Sort', 'blue'),
    )
    for times, name, colour in series:
        plt.plot(x_values, times, label=name, color=colour, linewidth=2)

    # The Y-axis label and chart title depend on the chosen scale.
    if not use_log_scale:
        y_label, title = 'Execution Time (seconds)', 'Time Complexity'
    else:
        plt.yscale('log')
        y_label = 'Execution Time (seconds) [Log Scale]'
        title = 'Time Complexity (Logarithmic Scale)'
    plt.ylabel(y_label)
    plt.title(title)

    plt.xlabel('Input Size (n)')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)

    plt.tight_layout()
    plt.show()


if __name__ == "__main__":
    # Input sizes to benchmark: n = 100, 200, ..., 10000.
    size_range = range(100, 10001, 100)

    def _timed_run(banner, algorithm):
        """Print a banner, time `algorithm` over every size, then report done."""
        print(banner)
        measured = time_measurement(algorithm, size_range)
        print("Done...")
        return measured

    selection_times = _timed_run(
        "Measuring Selection Sort... (This might take some time)",
        selection_sort,
    )
    merge_times = _timed_run(
        "Measuring Merge Sort...(This might take some time)",
        merge_sort,
    )

    # One x-value per measured input size, in the same order as the timings.
    x_values = list(size_range)

    # Linear-scale chart; use the commented-out call instead for a log-scale Y-axis.
    plot_results(x_values, selection_times, merge_times, use_log_scale=False)
    # plot_results(x_values, selection_times, merge_times, use_log_scale=True)
