#include "future.h"
#include "coroutine.h"
#include "task.h"
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <errno.h>
#include <math.h>
#include <string.h>
#include "timespec_utils.h"


Future_vfptrs_t Future_vfptrs = {
    &Future_dtor,
    &_Future_Await,
    &_Future_SetResult
};

/*
 * One registered completion callback. When the future completes it is
 * invoked as watcher(me, fut) (see Future_AddWatcher and _Future_Ready).
 */
typedef struct Future_WatcherSpec {
    Future_Watcher watcher;   // callback to invoke on completion
    void *me;                 // opaque context passed back as the callback's first argument
} Future_WatcherSpec;

/*
 * Initialise a Future in place: not yet resolved, no stored value, no
 * watchers, dispatching through the base-class vtable.
 */
void Future_ctor(Future *fut)
{
    /* Result slot: nothing has been produced yet. */
    fut->state = Future_State_Waiting;
    fut->canceled = false;
    fut->value = NULL;

    /* Watcher list starts empty; Future_AddWatcher allocates it lazily. */
    fut->watchers = NULL;
    fut->maxwatchers = 0;
    fut->nwatchers = 0;

    /* Base-class dispatch table and the lock guarding state and watchers. */
    fut->vfptrs = &Future_vfptrs;
    _Cor_Mutex_ctor(&fut->mutex);
}


/*
 * Allocate and construct a base Future on the heap.
 *
 * Returns the new future, or NULL if memory could not be allocated.
 * Release it with Future_Delete().
 */
Future *Future_New(void){
    Future *fut = malloc(sizeof *fut);
    if (fut == NULL) {
        // Previously Future_ctor would have written through a NULL pointer here.
        return NULL;
    }
    Future_ctor(fut);
    return fut;
}


/*
 * Base-class destructor: release the watcher storage and destroy the mutex.
 * Does not free `fut` itself (Future_Delete does that).
 *
 * All watchers must already have fired or been removed; this is checked with
 * assert().
 */
void Future_dtor(
    Future *fut
){
    _Cor_Mutex_Lock(&fut->mutex);
    // Destroying a future that still has pending watchers would strand them.
    assert(fut->nwatchers == 0);
    free(fut->watchers);
    // Do not leave a dangling pointer or stale capacity behind in the object.
    fut->watchers = NULL;
    fut->maxwatchers = 0;
    _Cor_Mutex_Unlock(&fut->mutex);
    _Cor_Mutex_dtor(&fut->mutex);
}


/*
 * Destroy a heap-allocated future. It first runs the destructor through the
 * vtable, so subclass destructors take part, and then frees the storage.
 *
 * Passing NULL is a no-op, mirroring free().
 */
void Future_Delete(
    Future *fut
){
    if (fut == NULL) {
        return;
    }
    fut->vfptrs->dtor(fut);
    free(fut);
}


/*
 * Register `watcher` to be called as watcher(watcher_me, fut) once `fut`
 * completes.
 *
 * If the future is already done, the watcher is invoked immediately on the
 * calling thread, after the mutex has been released. Otherwise it is appended
 * to the watcher list, whose storage grows geometrically, and _Future_Ready
 * calls it later.
 *
 * This API has no error return, so the process aborts if the watcher list
 * cannot be grown.
 */
void Future_AddWatcher(
    Future *fut,
    Future_Watcher watcher,
    void *watcher_me
){
    _Cor_Mutex_Lock(&fut->mutex);
    if (fut->state == Future_State_Done) {
        // Run the callback without holding the lock.
        _Cor_Mutex_Unlock(&fut->mutex);
        watcher(watcher_me, fut);
        return;
    }

    // Also grow when the buffer is NULL, so a stale capacity value can never
    // lead to a write through a NULL pointer.
    if (fut->watchers == NULL || fut->nwatchers >= fut->maxwatchers) {
        size_t newmax = (fut->maxwatchers == 0) ? 4 : (size_t)fut->maxwatchers * 2;
        // Keep the old buffer until realloc succeeds. Assigning the result
        // straight to fut->watchers would leak it on failure.
        Future_WatcherSpec *grown = realloc(fut->watchers, newmax * sizeof(Future_WatcherSpec));
        if (grown == NULL) {
            // An assert() here would vanish under NDEBUG and the code would
            // then write through NULL. Fail loudly instead.
            fputs("Future_AddWatcher: out of memory\n", stderr);
            abort();
        }
        fut->watchers = grown;
        fut->maxwatchers = newmax;
    }

    fut->watchers[fut->nwatchers].watcher = watcher;
    fut->watchers[fut->nwatchers].me = watcher_me;
    fut->nwatchers++;
    _Cor_Mutex_Unlock(&fut->mutex);
}


/*
 * Unregister the first watcher entry whose callback and context both match.
 * A pair that is not registered is silently ignored.
 *
 * The last entry is moved into the vacated slot, so the order of the
 * remaining watchers may change.
 */
void Future_RemoveWatcher(Future *fut, Future_Watcher watcher, void *watcher_me)
{
    _Cor_Mutex_Lock(&fut->mutex);
    int idx = 0;
    while (idx < fut->nwatchers) {
        Future_WatcherSpec *spec = &fut->watchers[idx];
        if (spec->watcher != watcher || spec->me != watcher_me) {
            idx++;
            continue;
        }
        /* Shrink the list and fill the hole from the tail, if there is one. */
        fut->nwatchers--;
        if (idx < fut->nwatchers) {
            *spec = fut->watchers[fut->nwatchers];
        }
        break;
    }
    _Cor_Mutex_Unlock(&fut->mutex);
}


/*
 * Mark `fut` as done and notify every registered watcher.
 *
 * The caller must hold fut->mutex. This function releases it before any
 * watcher runs, so the callbacks execute without the lock held.
 */
static void _Future_Ready(
    Future *fut
){
    fut->state = Future_State_Done;

    // Detach the watcher list so it can be walked without the lock. Reset all
    // three bookkeeping fields together, so the future never claims capacity
    // (maxwatchers) for a buffer it no longer owns.
    Future_WatcherSpec *watchers = fut->watchers;
    int nwatchers = fut->nwatchers;
    fut->watchers = NULL;
    fut->nwatchers = 0;
    fut->maxwatchers = 0;
    _Cor_Mutex_Unlock(&fut->mutex);

    // Notify the watchers that were registered before completion.
    for (int i = 0; i < nwatchers; i++) {
        watchers[i].watcher(watchers[i].me, fut);
    }
    free(watchers);
}


/*
 * Public entry point for resolving a future. It forwards to the vtable's
 * set_result entry, so subclasses may override how a result is recorded.
 */
void Future_SetResult(Future *fut, bool canceled, void *value)
{
    (*fut->vfptrs->set_result)(fut, canceled, value);
}


/*
 * Non-blocking result query.
 *
 * If `fut` has not completed, returns Coroutine_Err_WrongState and leaves
 * *res untouched. Otherwise stores the value in *res (when res is non-NULL)
 * and returns Coroutine_Err_Canceled or Coroutine_OK according to the
 * canceled flag.
 */
Coroutine_Err Future_GetResult(Future *fut, void **res)
{
    if (fut->state == Future_State_Done) {
        if (res != NULL) {
            *res = fut->value;
        }
        if (fut->canceled) {
            return Coroutine_Err_Canceled;
        }
        return Coroutine_OK;
    }
    return Coroutine_Err_WrongState;
}


/*
 * Context handed from _Future_Await to the yield callback on_yield_for_future:
 * the future being awaited, and the coroutine to resume once it completes.
 */
typedef struct future_bits {
    Future *fut;      // future being awaited
    Coroutine *cor;   // coroutine that suspended waiting on `fut`
} future_bits;


/*
 * Watcher callback used by _Future_Await. `me` is the suspended coroutine,
 * which is resumed here. The completed future itself is not needed.
 */
static void future_complete(void *me, Future *fut)
{
    (void)fut;
    Coroutine_NS(Continue)((Coroutine *)me, NULL, false);
}


/*
 * Yield callback used by _Future_Await. `me` is a future_bits describing what
 * to wait on. It registers future_complete so that the waiting coroutine is
 * resumed when the future completes.
 */
static void on_yield_for_future(void *me)
{
    const future_bits *bits = me;
    Future_AddWatcher(bits->fut, future_complete, bits->cor);
}


/*
 * Base-class implementation of the vtable `await` entry: suspends the calling
 * coroutine until `fut` completes.
 *
 * The completion watcher is registered from inside the yield callback
 * (on_yield_for_future), not before yielding.
 * NOTE(review): presumably this ensures a completion cannot try to resume the
 * coroutine before it has actually suspended. Confirm against the
 * Coroutine_NS(Yield) contract.
 */
void _Future_Await(
    Future *fut
){
    // `bits` is a local of this frame. It remains valid because this function
    // does not return until the coroutine is resumed after the Yield below.
    future_bits bits;
    bits.fut = fut;
    bits.cor = Coroutine_NS(GetActive)();
    // While suspended, the coroutine is detached from the global current_task.
    // It is restored after resuming.
    // NOTE(review): assumes current_task is non-NULL here; it is dereferenced
    // unconditionally below.
    Task *my_task = current_task;
    current_task = NULL;
    my_task->awaiting_future = fut;
    // If the owning task is already canceled, resolve the awaited future as
    // canceled with the task's cancel value. With the base set_result this
    // marks the future done, so the watcher added during the yield is invoked
    // immediately instead of waiting.
    if (my_task->canceled){
        Future_SetResult(fut, true, my_task->cancel_value);
    }
    Coroutine_NS(Yield)(NULL, on_yield_for_future, &bits);
    // Resumed: clear the bookkeeping and reattach to the task.
    my_task->awaiting_future = NULL;
    current_task = my_task;
}


/*
 * Base-class implementation of the vtable `set_result` entry.
 *
 * Only the first call on a waiting future has any effect. It records the
 * canceled flag and value, then hands off to _Future_Ready, which marks the
 * future done, releases the mutex and notifies watchers. Later calls are
 * ignored.
 */
void _Future_SetResult(Future *fut, bool canceled, void *res)
{
    _Cor_Mutex_Lock(&fut->mutex);
    if (fut->state != Future_State_Waiting) {
        /* Already resolved: the first result wins. */
        _Cor_Mutex_Unlock(&fut->mutex);
        return;
    }
    fut->value = res;
    fut->canceled = canceled;
    /* _Future_Ready takes over the held lock and releases it. */
    _Future_Ready(fut);
}


/*
 * Wait on `fut` through its vtable `await` entry, then report the outcome.
 *
 * Afterwards *res (when res is non-NULL) receives the stored value. The
 * return value is Coroutine_Err_Canceled if the future was canceled, and
 * Coroutine_OK otherwise.
 */
Coroutine_Err Future_Await(Future *fut, void **res)
{
    (*fut->vfptrs->await)(fut);

    if (res != NULL) {
        *res = fut->value;
    }
    if (fut->canceled) {
        return Coroutine_Err_Canceled;
    }
    return Coroutine_OK;
}
