#include "coroutine.h"
#include <assert.h>
#include <setjmp.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#include <pthread.h>
#include <stdlib.h>

#define COROUTINE_STACK_SIZE 16384
#define COROUTINE_STARTUP_STACK_SIZE 1024


// Allocate `size` bytes or terminate the process.
//
// Never returns NULL. On allocation failure a message is written to stderr
// and abort() is called. The check is an explicit test rather than assert()
// so that it still fires when the file is built with NDEBUG. A zero-byte
// request is rounded up to one byte because malloc(0) is permitted to return
// NULL even when memory is available.
//
// The caller owns the returned memory and releases it with free().
static void *mustmalloc(size_t size){
    void *p = malloc(size ? size : 1);
    if (p == NULL) {
        fprintf(stderr, "mustmalloc: failed to allocate %zu bytes\n", size);
        abort();
    }
    return p;
}

// New(type, args...): allocate sizeof(type) bytes with mustmalloc and hand the
// pointer, followed by any extra arguments, to type##_ctor. The value of the
// expression is whatever type##_ctor returns.
// The closing parenthesis of the mustmalloc call sits before the variadic
// arguments so that they reach the constructor, not the allocator.
// (`, ## __VA_ARGS__` is the GNU comma-swallowing extension, as before.)
#define New(type, ...) (type##_ctor((type *)mustmalloc(sizeof(type)), ## __VA_ARGS__))
// Delete(ptr, type): if ptr is non-NULL, run type##_dtor on it, free it and
// reset ptr to NULL; otherwise do nothing. ptr must be a modifiable lvalue and
// is evaluated more than once. Both arms of the conditional are void so the
// expression is valid ISO C.
#define Delete(ptr, type) ((ptr) ? (type##_dtor(ptr), free(ptr), (void)((ptr) = NULL)) : (void)0)

///////////////////////////////////////////////////////////////////////////////
// Semaphore built from mutex & condition variables...
//
// Using pthread.h (more widely available than the C standard library threads.h)
///////////////////////////////////////////////////////////////////////////////

// Counting semaphore. Semaphore_Claim blocks on `cond` while `count` is not
// positive and then decrements it; Semaphore_Release increments `count` and
// broadcasts `cond`. Both take `mutex` around their access to `count`.
typedef struct Semaphore {
    pthread_mutex_t mutex;  // held while count is read or modified
    pthread_cond_t cond;    // waited on by Claim, broadcast by Release
    int count;              // number of claims that can succeed without blocking
} Semaphore;

// Initialise `sem` in place: default-attribute mutex and condition variable,
// and a starting count of `initial_count`. Initialisation failures are caught
// by assert only.
static void Semaphore_ctor(Semaphore *sem, int initial_count){
    int rc;

    rc = pthread_mutex_init(&sem->mutex, NULL);
    assert(rc == 0);

    rc = pthread_cond_init(&sem->cond, NULL);
    assert(rc == 0);

    sem->count = initial_count;
}

// Tear down the pthread objects inside `sem`: the mutex first, then the
// condition variable. The Semaphore storage itself is not freed here.
static void Semaphore_dtor(Semaphore *sem){
    int rc;

    rc = pthread_mutex_destroy(&sem->mutex);
    assert(rc == 0);

    rc = pthread_cond_destroy(&sem->cond);
    assert(rc == 0);
}

// Take one unit from the semaphore, blocking until one is available.
static void Semaphore_Claim(Semaphore *sem){
    int rc = pthread_mutex_lock(&sem->mutex);
    assert(rc == 0);

    // Re-check the count after every wake-up: Release broadcasts, so several
    // waiters may wake for a single unit.
    for (;;) {
        if (sem->count > 0) {
            break;
        }
        rc = pthread_cond_wait(&sem->cond, &sem->mutex);
        assert(rc == 0);
    }
    sem->count -= 1;

    rc = pthread_mutex_unlock(&sem->mutex);
    assert(rc == 0);
}

// Return one unit to the semaphore and wake every thread blocked in
// Semaphore_Claim so that they re-test the count.
static void Semaphore_Release(Semaphore *sem){
    int rc;

    rc = pthread_mutex_lock(&sem->mutex);
    assert(rc == 0);

    sem->count += 1;

    // Broadcast while still holding the mutex, then let go of it.
    rc = pthread_cond_broadcast(&sem->cond);
    assert(rc == 0);

    rc = pthread_mutex_unlock(&sem->mutex);
    assert(rc == 0);
}

///////////////////////////////////////////////////////////////////////////////
// ...semaphore built from mutex
///////////////////////////////////////////////////////////////////////////////

///////////////////////////////////////////////////////////////////////////////
// 2-way linked lists...
//
// Brought inline here to avoid namespace pollution
///////////////////////////////////////////////////////////////////////////////

// Intrusive doubly-linked list node. It is embedded as a member of the object
// being listed (struct Coroutine embeds one); List_Link_Container maps a node
// pointer back to the enclosing object.
typedef struct List_Link List_Link;
struct List_Link {
    List_Link *next;    // following node, or the list head's back.link sentinel
    List_Link *prev;    // preceding node, or the list head's fwd.link sentinel
};

// List head holding two overlapping sentinel nodes in three pointer slots.
//
// Through the `fwd` view the slots are { link.next, link.prev, filler };
// through the `back` view they are { filler, link.next, link.prev }. So
// fwd.link.prev and back.link.next name the same middle slot, which List_Init
// sets to NULL. fwd.link acts as the node before the first element and
// back.link as the node after the last one, which lets List_AddHead,
// List_AddTail and List_Remove work without special-casing an empty list.
typedef struct List_Head List_Head;
struct List_Head {
    union {
        struct {
            List_Link link;     // head sentinel: link.next is the first element
            List_Link *filler;  // overlays back.link.prev; not accessed by name
        } fwd;
        struct {
            List_Link *filler;  // overlays fwd.link.next; not accessed by name
            List_Link link;     // tail sentinel: link.prev is the last element
        } back;
    };
};

static inline bool List_IsEmpty(const List_Head *list) {
    return list->fwd.link.next == &list->back.link;
}

// First element of the list, or NULL when the list is empty. The element is
// left in place.
static inline List_Link *List_GetHead(const List_Head *list) {
    if (List_IsEmpty(list)) {
        return NULL;
    }
    return list->fwd.link.next;
}
// Last element of the list, or NULL when the list is empty. The element is
// left in place.
static inline List_Link *List_GetTail(const List_Head *list) {
    if (List_IsEmpty(list)) {
        return NULL;
    }
    return list->back.link.prev;
}
// Byte offset of member `Field` within type `Container`, as a ptrdiff_t.
// Implemented with the standard offsetof() from <stddef.h> instead of taking a
// member address through a made-up pointer value, which is undefined
// behaviour. The cast keeps the result signed, as it was before.
#define OFFSETOF(Container, Field) ((ptrdiff_t)offsetof(Container, Field))
// Given `link`, a pointer to the member named `Link` embedded in a
// `Container`, yield a pointer to that enclosing Container.
#define List_Link_Container(Container, Link, link) ((Container *)((char *)(link) - OFFSETOF(Container, Link)))

// Put `list` into the empty state: the two sentinels point at each other and
// the slot they share (fwd.link.prev) is NULL.
static inline void List_Init(List_Head *list)
{
    List_Link *head_sentinel = &list->fwd.link;
    List_Link *tail_sentinel = &list->back.link;

    head_sentinel->next = tail_sentinel;
    head_sentinel->prev = NULL;
    tail_sentinel->prev = head_sentinel;
}

// Insert `link` as the first element of `list`. `link` must not currently be
// on any list.
static inline void List_AddHead(List_Head *list, List_Link *link)
{
    List_Link *anchor = &list->fwd.link;
    List_Link *old_first = anchor->next;    // tail sentinel when the list is empty

    // splice link between the head sentinel and the previous first node
    link->prev = anchor;
    link->next = old_first;
    old_first->prev = link;
    anchor->next = link;
}

// Insert `link` as the last element of `list`. `link` must not currently be
// on any list.
static inline void List_AddTail(List_Head *list, List_Link *link)
{
    List_Link *anchor = &list->back.link;
    List_Link *old_last = anchor->prev;     // head sentinel when the list is empty

    // splice link between the previous last node and the tail sentinel
    link->next = anchor;
    link->prev = old_last;
    old_last->next = link;
    anchor->prev = link;
}

// Unlink `link` from whichever list it is on by joining its two neighbours.
// The node's own next/prev fields are left unchanged.
static inline void List_Remove(List_Link *link)
{
    List_Link *before = link->prev;
    List_Link *after = link->next;

    before->next = after;
    after->prev = before;
}

///////////////////////////////////////////////////////////////////////////////
// ...2-way linked lists
///////////////////////////////////////////////////////////////////////////////


// Values of g_c.state: the life cycle of the coroutine system as a whole.
enum {
    Coroutines_Idle,        // not started; required by Coroutine_StartSystem, restored by Coroutine_StopSystem
    Coroutines_Starting,    // set while Coroutine_StartSystem is initialising
    Coroutines_Started,     // initialised, Coroutine_Run not in progress
    Coroutines_Active,      // set for the duration of Coroutine_Run
    Coroutines_Stopping     // set while Coroutine_StopSystem is tearing down
};

// Values seen as the result of setjmp() on a Coroutine's `buf`.
enum {
    Chunk_Initial,  // 0: the direct return from setjmp itself
    Chunk_Create,   // sent by Coroutine_New: reserve another stack chunk beyond this one
    Chunk_Enter     // sent by Coroutine_RunNext: start or resume the coroutine
};

// Values of Coroutine.state: the life cycle of a single coroutine slot.
enum {
    Coroutine_Constructing, // set by stack_chunk_base; not yet on any list
    Coroutine_Free,         // on g_c.free, available to Coroutine_New
    Coroutine_Idle,         // handed out by Coroutine_New, not yet continued
    Coroutine_Running,      // set by Coroutine_Continue; on g_c.runable
    Coroutine_Waiting,      // set by Coroutine_Yield; on g_c.waiting
    Coroutine_Complete      // start function has returned; on g_c.inactive
};

// Codes for jumps back to the controlling context. In the visible code only
// Coroutines_CoroutineComplete is used by name (as the longjmp value for
// g_c.controller).
// NOTE(review): Coroutines_Init (0) matches the direct return of setjmp and
// Coroutines_AllocatedChunk (1) matches the literal 1 passed to
// longjmp(g_c.chunk_allocated, ...) - presumably that is their intended use;
// confirm.
enum {
    Coroutines_Init,
    Coroutines_AllocatedChunk,
    Coroutines_CoroutineComplete,
};

// One coroutine slot. Instances are the `here` locals of stack_chunk_base, so
// each one lives inside the stack chunk it describes.
struct Coroutine {
    List_Link link;                     // membership of one of the g_c lists (free/inactive/runable/waiting)
    jmp_buf buf;                        // where to longjmp to create a chunk, start, or resume
    void *this;                         // caller context handed to on_yield
    Coroutine_YieldCallback on_yield;   // called from Coroutine_Yield after the coroutine is parked
    Coroutine_Start start;              // entry function, called with entry_param
    void *entry_param;                  // value from Coroutine_Continue: argument to start, or result of Coroutine_Yield
    void *value;                        // value passed to Coroutine_Yield, or the return value of start
    char state;                         // one of the Coroutine_* state constants
    char action;                        // not referenced in the visible code
};

// All state of the coroutine system; the single instance is g_c.
typedef struct Coroutines Coroutines;

struct Coroutines {
    pthread_mutex_t mutex;      // recursive mutex created in Coroutine_StartSystem
    jmp_buf controller;         // set in Coroutine_Run; jumped to when the primary coroutine completes
    jmp_buf chunk_allocated;    // set by whoever requests a stack chunk; jumped to once it exists

    // singletons
    Coroutine *tip;     // top of stack chunk: the most recently created slot, target of Chunk_Create
    Coroutine *active;  // currently running coroutine (NULL once it has completed)
    Coroutine *primary; // Coroutine_Run coroutine

    // lists
    List_Head free;         // slots in state Coroutine_Free
    List_Head inactive;     // idle or complete
    List_Head runable;      // running or waiting to run; Coroutine_RunNext resumes its head
    List_Head waiting;      // yielded, until Coroutine_Continue moves them to runable
    Semaphore waiting_sem;  // released once per Coroutine_Continue, claimed once per Coroutine_RunNext

    // state
    char state;             // one of Coroutines_Idle .. Coroutines_Stopping
};

// The one coroutine system instance. It is not declared static, so the symbol
// has external linkage.
Coroutines g_c;

// Forward declarations: stack_chunk_chunk and stack_chunk_base call each other.
static void stack_chunk_chunk(Coroutine *parent);
static void stack_chunk_base(Coroutine *parent);

// Create the first coroutine slot. Declares a COROUTINE_STACK_SIZE-byte local
// array and then calls stack_chunk_base(NULL); with a NULL parent nothing is
// added to the free list, the new slot just becomes g_c.tip. Control does not
// come back through a normal return: stack_chunk_base longjmps to
// g_c.chunk_allocated, which the caller must have set beforehand.
static void Coroutine_PrimeStackChunks()
{
    unsigned char chunk_of_stack[COROUTINE_STACK_SIZE];
    // Marker bytes at both ends of the array.
    // NOTE(review): presumably these stores exist so the array is really
    // materialised on the stack; nothing reads them back, so an optimising
    // compiler may be entitled to drop them - confirm with the build flags used.
    chunk_of_stack[0] = 0xde;
    chunk_of_stack[1] = 0xad;
    chunk_of_stack[2] = 0xbe;
    chunk_of_stack[3] = 0xef;
    chunk_of_stack[COROUTINE_STACK_SIZE - 4] = 0xde;
    chunk_of_stack[COROUTINE_STACK_SIZE - 3] = 0xad;
    chunk_of_stack[COROUTINE_STACK_SIZE - 2] = 0xbe;
    chunk_of_stack[COROUTINE_STACK_SIZE - 1] = 0xef;
    stack_chunk_base(NULL);
}

// Reserve one stack chunk for `parent`. Declares a COROUTINE_STACK_SIZE-byte
// local array in this frame and then calls stack_chunk_base(parent), which
// sets up the next slot beyond the array and puts `parent` on the free list.
// Called from stack_chunk_base's Chunk_Create case; control does not come back
// through a normal return (stack_chunk_base longjmps away).
static void stack_chunk_chunk(Coroutine *parent){
    unsigned char chunk_of_stack[COROUTINE_STACK_SIZE];
    // Marker bytes at both ends of the array.
    // NOTE(review): presumably these stores exist so the array is really
    // materialised on the stack; nothing reads them back, so an optimising
    // compiler may be entitled to drop them - confirm with the build flags used.
    chunk_of_stack[0] = 0xde;
    chunk_of_stack[1] = 0xad;
    chunk_of_stack[2] = 0xbe;
    chunk_of_stack[3] = 0xef;
    chunk_of_stack[COROUTINE_STACK_SIZE - 4] = 0xde;
    chunk_of_stack[COROUTINE_STACK_SIZE - 3] = 0xad;
    chunk_of_stack[COROUTINE_STACK_SIZE - 2] = 0xbe;
    chunk_of_stack[COROUTINE_STACK_SIZE - 1] = 0xef;
    stack_chunk_base(parent);
}

static void Coroutine_RunNext()
{
    // arrvie here with mutex unlocked
    Semaphore_Claim(&g_c.waiting_sem);
    int r = pthread_mutex_lock(&g_c.mutex);
    assert(r == 0);
    assert(!List_IsEmpty(&g_c.runable));
    Coroutine *next = List_Link_Container(Coroutine, link, List_GetHead(&g_c.runable));
    assert(next->state == Coroutine_Running);
    longjmp(next->buf, Chunk_Enter);
    assert(false);
}

// Base frame of one coroutine slot. The local `here` is the slot's Coroutine
// record and here.buf is the entry point used for everything done with the
// slot afterwards. This function never returns normally: every path ends in a
// longjmp or a call that does not come back.
//
// parent: the slot whose stack chunk has just been reserved by the caller, or
//         NULL for the very first slot.
//
// The setjmp result selects the action:
//   Chunk_Initial - first pass: publish parent as free, record this slot as
//                   g_c.tip and jump back to g_c.chunk_allocated.
//   Chunk_Create  - reserve a chunk for this slot by calling stack_chunk_chunk.
//   Chunk_Enter   - run the coroutine's start function to completion.
static void stack_chunk_base(Coroutine *parent){
    Coroutine here;
    here.state = Coroutine_Constructing;
    switch (setjmp(here.buf)) {
    case Chunk_Initial:
        // got here for the first time
        // parent now has a chunk_of_stack - add it to the free list
        if (parent) {
            assert(parent->state == Coroutine_Constructing);
            parent->state = Coroutine_Free;
            List_AddHead(&g_c.free, &parent->link);
        }
        // note that here is the tip of the chunk-claim stack
        g_c.tip = &here;

        // return to the coroutine allocator
        // (the value 1 is numerically equal to Coroutines_AllocatedChunk)
        longjmp(g_c.chunk_allocated, 1);
    case Chunk_Create:
        // request to create a new chunk on the stack
        // (sent by Coroutine_New when the free list is empty)
        assert(here.state == Coroutine_Constructing);
        stack_chunk_chunk(&here);
        assert(false);
    case Chunk_Enter:
        // request to start a coroutine (ie use the chunk for a coroutine)
        // arrive here with mutex locked
        assert(here.state == Coroutine_Running);
        g_c.active = &here;
        int r = pthread_mutex_unlock(&g_c.mutex);
        assert(r == 0);
        // run the coroutine body with the mutex released; its return value
        // becomes the coroutine's final value
        here.value = here.start(here.entry_param);
        r = pthread_mutex_lock(&g_c.mutex);
        assert(r == 0);
        g_c.active = NULL;
        assert(here.state == Coroutine_Running);
        // move from the runable list to the inactive list
        List_Remove(&here.link);
        here.state = Coroutine_Complete;
        List_AddTail(&g_c.inactive, &here.link);
        // coroutine has completed
        if (g_c.primary == &here) {
            // if primary coroutine - return to Coroutine_Run
            // (the mutex is still locked for this jump)
            longjmp(g_c.controller, Coroutines_CoroutineComplete);
        }
        // otherwise hand over to the next runable coroutine, which requires
        // the mutex to be unlocked
        r = pthread_mutex_unlock(&g_c.mutex);
        assert(r == 0);
        Coroutine_RunNext();
        assert(false);
    }
}

// Bring the coroutine system from Coroutines_Idle to Coroutines_Started:
// create the recursive global mutex, reset the singletons, lists and
// semaphore, and set up the first stack-chunk slot.
void Coroutine_StartSystem()
{
    assert(g_c.state == Coroutines_Idle);

    // g_c.mutex is recursive: Coroutine_Run calls Coroutine_Continue while it
    // already holds the lock.
    pthread_mutexattr_t recursive_attr;
    int rc = pthread_mutexattr_init(&recursive_attr);
    assert(rc == 0);
    rc = pthread_mutexattr_settype(&recursive_attr, PTHREAD_MUTEX_RECURSIVE);
    assert(rc == 0);
    rc = pthread_mutex_init(&g_c.mutex, &recursive_attr);
    assert(rc == 0);
    rc = pthread_mutexattr_destroy(&recursive_attr);
    assert(rc == 0);

    g_c.state = Coroutines_Starting;

    g_c.active = NULL;
    g_c.tip = NULL;

    Semaphore_ctor(&g_c.waiting_sem, 0);
    List_Init(&g_c.waiting);
    List_Init(&g_c.runable);
    List_Init(&g_c.inactive);
    List_Init(&g_c.free);

    // Create the first slot. Coroutine_PrimeStackChunks does not return; it
    // comes back here through a longjmp to g_c.chunk_allocated.
    if (setjmp(g_c.chunk_allocated) == 0){
        Coroutine_PrimeStackChunks();
        assert(false);
    }

    assert(g_c.state == Coroutines_Starting);
    g_c.state = Coroutines_Started;
}

// Shut the coroutine system down: Coroutines_Started -> Coroutines_Idle.
//
// Asserts that the inactive list is empty, i.e. every coroutine obtained from
// Coroutine_New has been given back with Coroutine_Delete. Destroys the
// semaphore and then the global mutex.
void Coroutine_StopSystem()
{
    int r = pthread_mutex_lock(&g_c.mutex);
    assert(r == 0);
    assert(g_c.state == Coroutines_Started);
    g_c.state = Coroutines_Stopping;

    assert(List_IsEmpty(&g_c.inactive));
    Semaphore_dtor(&g_c.waiting_sem);

    assert(g_c.state == Coroutines_Stopping);
    // Store the unlock result so the assert checks this call rather than the
    // stale value left over from pthread_mutex_lock.
    r = pthread_mutex_unlock(&g_c.mutex);
    assert(r == 0);
    g_c.state = Coroutines_Idle;
    // the mutex must be unlocked before it can be destroyed
    r = pthread_mutex_destroy(&g_c.mutex);
    assert(r == 0);
}

// Run `cor` as the primary coroutine and return once it has completed.
//
// cor:   coroutine to run; it is queued at the head of the runable list.
// value: passed to Coroutine_Continue as the coroutine's entry parameter.
// Returns Coroutine_GetValue(cor) after completion.
//
// The system must be in state Coroutines_Started; it is Coroutines_Active
// while this function is in progress. On completion the runable and waiting
// lists are asserted to be empty.
void *Coroutine_Run(Coroutine *cor, void *value){
    int r = pthread_mutex_lock(&g_c.mutex);
    assert(r == 0);
    assert(g_c.state == Coroutines_Started);
    g_c.state = Coroutines_Active;
    g_c.primary = cor;
    // takes the (recursive) mutex again internally
    Coroutine_Continue(g_c.primary, value, true);

    if (!setjmp(g_c.controller)){
        // Coroutine_RunNext expects the mutex to be unlocked. The result goes
        // into a branch-local variable so that no local of this frame that is
        // used after the longjmp is modified between setjmp and longjmp.
        int unlock_rc = pthread_mutex_unlock(&g_c.mutex);
        assert(unlock_rc == 0);
        // start the first coroutine
        Coroutine_RunNext();
    }
    // arrive here with mutex locked
    // (via longjmp from stack_chunk_base when the primary coroutine completes)
    assert(List_IsEmpty(&g_c.runable));
    assert(List_IsEmpty(&g_c.waiting));
    assert(g_c.state == Coroutines_Active);
    g_c.state = Coroutines_Started;
    // Store the unlock result so the assert checks this call rather than a
    // stale value.
    r = pthread_mutex_unlock(&g_c.mutex);
    assert(r == 0);
    return Coroutine_GetValue(cor);
}

// Obtain a coroutine slot and configure it.
//
// this:     context pointer later passed to on_yield.
// on_yield: callback invoked from Coroutine_Yield.
// start:    entry function run when the coroutine is first entered.
// Returns the slot in state Coroutine_Idle, placed on the inactive list with a
// NULL value. Give it back with Coroutine_Delete.
Coroutine *Coroutine_New(void *this, Coroutine_YieldCallback on_yield, Coroutine_Start start){
    assert(g_c.state == Coroutines_Started || g_c.state == Coroutines_Active);

    // No spare slot: ask the tip slot to reserve another stack chunk. Control
    // returns here through a longjmp to g_c.chunk_allocated, after which the
    // free list holds the new slot.
    if (List_IsEmpty(&g_c.free)){
        if (setjmp(g_c.chunk_allocated) == 0){
            longjmp(g_c.tip->buf, Chunk_Create);
        }
    }

    List_Link *spare = List_GetHead(&g_c.free);
    Coroutine *fresh = List_Link_Container(Coroutine, link, spare);
    assert(fresh->state == Coroutine_Free);

    // free list -> inactive list
    List_Remove(&fresh->link);
    List_AddHead(&g_c.inactive, &fresh->link);

    fresh->this = this;
    fresh->on_yield = on_yield;
    fresh->start = start;
    fresh->value = NULL;
    fresh->state = Coroutine_Idle;

    return fresh;
}

// Give a slot back to the free list. `cor` must be idle (never continued) or
// complete.
void Coroutine_Delete(Coroutine *cor){
    assert(cor->state == Coroutine_Idle || cor->state == Coroutine_Complete);

    // leave the current list and join the tail of the free list
    List_Remove(&cor->link);
    List_AddTail(&g_c.free, &cor->link);
    cor->state = Coroutine_Free;
}

// Make `cor` runnable.
//
// cor:   an idle or waiting coroutine.
// value: stored as the entry parameter - the argument of the start function on
//        first entry, or the result of the pending Coroutine_Yield.
// early: true queues the coroutine at the head of the runable list, false at
//        the tail.
// The semaphore is released once so that one Coroutine_RunNext can proceed.
void Coroutine_Continue(Coroutine *cor, void *value, bool early){
    int rc = pthread_mutex_lock(&g_c.mutex);
    assert(rc == 0);

    assert(cor->state == Coroutine_Idle || cor->state == Coroutine_Waiting);

    List_Remove(&cor->link);
    cor->state = Coroutine_Running;
    cor->entry_param = value;
    if ( !early ) {
        List_AddTail(&g_c.runable, &cor->link);
    } else {
        List_AddHead(&g_c.runable, &cor->link);
    }

    rc = pthread_mutex_unlock(&g_c.mutex);
    assert(rc == 0);

    // signal only after the mutex has been dropped
    Semaphore_Release(&g_c.waiting_sem);
}

// Suspend the active coroutine.
//
// value: stored as the coroutine's value, readable through Coroutine_GetValue.
// Returns the value passed to the Coroutine_Continue call that resumes this
// coroutine.
//
// The coroutine is moved to the waiting list, its on_yield callback is called
// with the mutex unlocked, and control passes to the next runable coroutine.
void *Coroutine_Yield(void *value){
    int r = pthread_mutex_lock(&g_c.mutex);
    assert(r == 0);
    Coroutine *me = g_c.active;
    assert(me && me->state == Coroutine_Running);
    me->value = value;
    me->state = Coroutine_Waiting;
    // runable list -> waiting list
    List_Remove(&me->link);
    List_AddTail(&g_c.waiting, &me->link);
    // From here on me->buf refers to this point, so a later Chunk_Enter
    // resumes inside this call.
    switch (setjmp(me->buf)){
    case Chunk_Initial:
        // direct return from setjmp: finish suspending
        r = pthread_mutex_unlock(&g_c.mutex);
        assert(r == 0);
        me->on_yield(me->this);
        // does not return, so the case below is never reached from here
        Coroutine_RunNext();
    case Chunk_Create:
        // a suspended coroutine is never asked to create a chunk
        assert(false);
    case Chunk_Enter:
        // arrive here with mutex locked
        g_c.active = me;
        // when we return here - we are running again
        assert(me->state == Coroutine_Running);
        void *res = me->entry_param;
        r = pthread_mutex_unlock(&g_c.mutex);
        assert(r == 0);
        return res;
    }
    return NULL;
}

// Latest value of `cor`: what it last passed to Coroutine_Yield, or the return
// value of its start function once it has completed (NULL before either).
void *Coroutine_GetValue(Coroutine *cor){
    void *latest = cor->value;
    return latest;
}

// The coroutine recorded in g_c.active. Read without taking the mutex.
Coroutine *Coroutine_GetActive()
{
    Coroutine *current = g_c.active;
    return current;
}

// True while `cor` has been started and has not yet completed, i.e. its state
// is Coroutine_Running or Coroutine_Waiting.
bool Coroutine_IsRunning(Coroutine *cor)
{
    switch (cor->state) {
    case Coroutine_Running:
    case Coroutine_Waiting:
        return true;
    default:
        return false;
    }
}
