#include "coroutine.h"

#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <semaphore.h>
#include <setjmp.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#include <unistd.h>

// Number of bytes of stack padding reserved per coroutine chunk; this is the
// size of the chunk_of_stack arrays declared further down in this file.
#define COROUTINE_STACK_SIZE 16384
// NOTE(review): not referenced anywhere in the visible part of this file.
#define COROUTINE_STARTUP_STACK_SIZE 1024


///////////////////////////////////////////////////////////////////////////////
// 2-way linked lists...
//
// Brought inline here to avoid namespace pollution
///////////////////////////////////////////////////////////////////////////////

// A single node of a doubly linked list. Embed one of these in any structure
// that needs to be placed on a list.
typedef struct List_Link {
    struct List_Link *next;  // node that follows this one
    struct List_Link *prev;  // node that precedes this one
} List_Link;

// List header that packs two sentinel nodes into the space of three pointers.
//
// The two views of the anonymous union overlap as follows:
//
//   slot 0:  fwd.link.next   /  back.filler      -> first node on the list
//   slot 1:  fwd.link.prev   /  back.link.next   -> NULL after List_Init
//   slot 2:  fwd.filler      /  back.link.prev   -> last node on the list
//
// fwd.link serves as the sentinel in front of the first node and back.link as
// the sentinel behind the last node; the two sentinels share slot 1.
typedef struct List_Head {
    union {
        struct {
            List_Link link;
            List_Link *filler;
        } fwd;
        struct {
            List_Link *filler;
            List_Link link;
        } back;
    };
} List_Head;

static inline bool List_IsEmpty(const List_Head *list) {
    return list->fwd.link.next == &list->back.link;
}

// Return the first node on the list, or NULL when the list is empty.
static inline List_Link *List_GetHead(const List_Head *list) {
    if (List_IsEmpty(list)) {
        return NULL;
    }
    return list->fwd.link.next;
}
// Return the last node on the list, or NULL when the list is empty.
static inline List_Link *List_GetTail(const List_Head *list) {
    if (List_IsEmpty(list)) {
        return NULL;
    }
    return list->back.link.prev;
}
// Byte offset of member Field inside type Container. Uses the standard
// offsetof() from <stddef.h> instead of arithmetic on a made-up pointer value,
// which the language does not define.
#define OFFSETOF(Container, Field) offsetof(Container, Field)
// Given `link`, a pointer to the member named Link embedded in an object of
// type Container, yield a pointer to that enclosing Container object.
#define List_Link_Container(Container, Link, link) ((Container *)((char *)(link) - OFFSETOF(Container, Link)))

// Put the list into its empty state: the head sentinel points forward at the
// tail sentinel, the tail sentinel points back at the head sentinel, and the
// pointer slot shared by both sentinels is NULL.
static inline void List_Init(List_Head *list)
{
    List_Link *head_sentinel = &list->fwd.link;
    List_Link *tail_sentinel = &list->back.link;

    head_sentinel->next = tail_sentinel;
    head_sentinel->prev = NULL;  // same storage as tail_sentinel->next
    tail_sentinel->prev = head_sentinel;
}

// Insert `link` at the front of the list, directly after the head sentinel.
static inline void List_AddHead(List_Head *list, List_Link *link)
{
    List_Link *head_sentinel = &list->fwd.link;
    List_Link *old_first = head_sentinel->next;

    // Wire up the new node first, then splice it between its neighbours.
    link->prev = head_sentinel;
    link->next = old_first;
    old_first->prev = link;
    head_sentinel->next = link;
}

// Insert `link` at the back of the list, directly before the tail sentinel.
static inline void List_AddTail(List_Head *list, List_Link *link)
{
    List_Link *tail_sentinel = &list->back.link;
    List_Link *old_last = tail_sentinel->prev;

    // Wire up the new node first, then splice it between its neighbours.
    link->next = tail_sentinel;
    link->prev = old_last;
    old_last->next = link;
    tail_sentinel->prev = link;
}

// Unlink `link` from whichever list it is currently on by joining its two
// neighbours together. The node's own next/prev pointers are left untouched.
static inline void List_Remove(List_Link *link)
{
    List_Link *before = link->prev;
    List_Link *after = link->next;

    before->next = after;
    after->prev = before;
}

///////////////////////////////////////////////////////////////////////////////
// ...2-way linked lists
///////////////////////////////////////////////////////////////////////////////


// Lifecycle of the coroutine system as a whole (the values held in g_c.state).
// Coroutines_Idle is 0, which is the state a zero-initialised g_c starts in.
enum {
    Coroutines_Idle     = 0,  // not started, or stopped again
    Coroutines_Starting = 1,  // Coroutine_StartSystem in progress
    Coroutines_Started  = 2,  // started; no Coroutine_Run in progress
    Coroutines_Active   = 3,  // Coroutine_Run in progress
    Coroutines_Stopping = 4   // Coroutine_StopSystem in progress
};

// Values seen by the setjmp() in stack_chunk_base. Chunk_Initial must be 0
// because that is what a direct setjmp() call returns; the other two are
// delivered through longjmp().
enum {
    Chunk_Initial = 0,  // first pass through a freshly created chunk frame
    Chunk_Create  = 1,  // push another chunk below this one (from Coroutine_New)
    Chunk_Enter   = 2   // enter or resume the coroutine (from Coroutine_RunNext)
};

// Per-coroutine states (the values held in Coroutine.state).
enum {
    Coroutine_Constructing = 0,  // chunk frame exists but is not on any list yet
    Coroutine_Free         = 1,  // on g_c.free, available to Coroutine_New
    Coroutine_Idle         = 2,  // handed out by Coroutine_New, not yet continued
    Coroutine_Running      = 3,  // on g_c.runable after Coroutine_Continue
    Coroutine_Waiting      = 4,  // on g_c.waiting after Coroutine_Yield
    Coroutine_Complete     = 5   // start function has returned
};

// Codes associated with the g_c.controller jump buffer. Only
// Coroutines_CoroutineComplete is passed to longjmp() in the visible code; it
// must be non-zero so that Coroutine_Run can tell it from the initial setjmp().
enum {
    Coroutines_Init              = 0,
    Coroutines_AllocatedChunk    = 1,
    Coroutines_CoroutineComplete = 2
};

// Control block for one coroutine. Every instance in this file is the local
// variable `here` of a stack_chunk_base frame, so the block lives on the stack
// chunk it describes.
struct Coroutine {
    // Links this coroutine into one of the g_c lists.
    List_Link link;
    // Jump target used to enter or resume this coroutine (or, for the tip
    // chunk, to request creation of another chunk).
    jmp_buf buf;
    // Opaque pointer supplied to Coroutine_New and handed to on_yield.
    void *this;
    // Called from Coroutine_Yield before control moves to another coroutine.
    Coroutine_YieldCallback on_yield;
    // Entry function; invoked with entry_param when the coroutine first runs.
    Coroutine_Start start;
    // Value most recently passed in through Coroutine_Continue.
    void *entry_param;
    // Value most recently yielded, or the return value of start on completion.
    void *value;
    // One of the Coroutine_* state constants.
    char state;
    // NOTE(review): not read or written anywhere in the visible part of this file.
    char action;
};

typedef struct Coroutines Coroutines;

// Global bookkeeping for the whole coroutine system (see g_c).
struct Coroutines {
    // Set by Coroutine_Run; jumped to when the primary coroutine completes.
    jmp_buf controller;
    // Set by Coroutine_StartSystem / Coroutine_New; jumped to by
    // stack_chunk_base once a new chunk frame has been set up.
    jmp_buf chunk_allocated;

    // singletons
    Coroutine *tip;     // most recently created chunk frame (still Constructing)
    Coroutine *active;  // coroutine whose code is currently executing, or NULL
    Coroutine *primary; // coroutine passed to Coroutine_Run

    // lists
    List_Head free;         // chunks available to Coroutine_New
    List_Head inactive;     // idle or complete
    List_Head runable;      // state Coroutine_Running: executing or queued to execute
    List_Head waiting;      // state Coroutine_Waiting: yielded, not yet continued
    // Posted once per Coroutine_Continue and waited on once per
    // Coroutine_RunNext.
    sem_t *waiting_sem;

    // state: one of the Coroutines_Idle .. Coroutines_Stopping constants
    char state;
};

// The single coroutine system instance. As a zero-initialised global its state
// field starts out as Coroutines_Idle.
Coroutines g_c;

// Forward declarations: stack_chunk_chunk and stack_chunk_base call each other.
static void stack_chunk_chunk(Coroutine *parent);
static void stack_chunk_base(Coroutine *parent);

// Reserve the first COROUTINE_STACK_SIZE bytes of stack padding and create the
// first chunk frame beneath it by calling stack_chunk_base(NULL).
//
// stack_chunk_base is expected to leave via longjmp rather than return, so
// this frame (and its padding array) stays in place for as long as the chunk
// chain is in use.
static void Coroutine_PrimeStackChunks(void)
{
    // volatile: nothing reads the buffer for its contents, so without it an
    // optimising compiler is free to drop the stores and the array itself,
    // which would remove the stack reservation this function exists to make.
    volatile unsigned char chunk_of_stack[COROUTINE_STACK_SIZE];
    // Recognisable byte patterns at both ends of the padding; nothing in the
    // visible code reads them back.
    chunk_of_stack[0] = 0xde;
    chunk_of_stack[1] = 0xad;
    chunk_of_stack[2] = 0xbe;
    chunk_of_stack[3] = 0xef;
    chunk_of_stack[COROUTINE_STACK_SIZE - 4] = 0xde;
    chunk_of_stack[COROUTINE_STACK_SIZE - 3] = 0xad;
    chunk_of_stack[COROUTINE_STACK_SIZE - 2] = 0xbe;
    chunk_of_stack[COROUTINE_STACK_SIZE - 1] = 0xef;
    stack_chunk_base(NULL);
    // Not expected to execute. Touching the buffer after the call keeps the
    // call out of tail position, so the compiler cannot release this frame
    // (and with it the padding) before entering stack_chunk_base.
    (void)chunk_of_stack[0];
}

// Reserve another COROUTINE_STACK_SIZE bytes of stack padding below the
// caller's frame and create the next chunk frame beneath it.
//
// parent: the chunk frame that requested the new chunk; passed straight on to
//         stack_chunk_base, which moves it to the free list.
//
// stack_chunk_base is expected to leave via longjmp rather than return, so
// this frame (and its padding array) stays in place for as long as the chunk
// chain is in use.
static void stack_chunk_chunk(Coroutine *parent){
    // volatile: nothing reads the buffer for its contents, so without it an
    // optimising compiler is free to drop the stores and the array itself,
    // which would remove the stack reservation this function exists to make.
    volatile unsigned char chunk_of_stack[COROUTINE_STACK_SIZE];
    // Recognisable byte patterns at both ends of the padding; nothing in the
    // visible code reads them back.
    chunk_of_stack[0] = 0xde;
    chunk_of_stack[1] = 0xad;
    chunk_of_stack[2] = 0xbe;
    chunk_of_stack[3] = 0xef;
    chunk_of_stack[COROUTINE_STACK_SIZE - 4] = 0xde;
    chunk_of_stack[COROUTINE_STACK_SIZE - 3] = 0xad;
    chunk_of_stack[COROUTINE_STACK_SIZE - 2] = 0xbe;
    chunk_of_stack[COROUTINE_STACK_SIZE - 1] = 0xef;
    stack_chunk_base(parent);
    // Not expected to execute. Touching the buffer after the call keeps the
    // call out of tail position, so the compiler cannot release this frame
    // (and with it the padding) before entering stack_chunk_base.
    (void)chunk_of_stack[0];
}

// Transfer control to the coroutine at the head of the runable list.
//
// Blocks on g_c.waiting_sem until a Coroutine_Continue has posted it, then
// longjmps into the head coroutine's saved context with Chunk_Enter. The
// coroutine is left on the runable list while it executes. This function does
// not return to its caller.
static void Coroutine_RunNext(void)
{
    // sem_wait may return early with EINTR when a signal handler runs; in that
    // case no post has been consumed, so wait again instead of proceeding.
    int rc;
    do {
        rc = sem_wait(g_c.waiting_sem);
    } while (rc == -1 && errno == EINTR);
    assert(rc == 0);

    assert(!List_IsEmpty(&g_c.runable));
    Coroutine *next = List_Link_Container(Coroutine, link, List_GetHead(&g_c.runable));
    assert(next->state == Coroutine_Running);
    longjmp(next->buf, Chunk_Enter);
    assert(false);
}

// Body of one stack chunk. The local `here` is the control block of the
// coroutine that will run on this chunk, and the setjmp below is the entry
// point other code longjmps to in order to drive the chunk.
//
// parent: the chunk frame that asked for this chunk to be created, or NULL for
//         the very first chunk. It becomes a free coroutine once this frame
//         exists beneath it.
//
// Every case below is meant to leave via longjmp (directly or through
// Coroutine_RunNext / stack_chunk_chunk); the assert(false) lines mark points
// that are not expected to be reached.
//
// NOTE(review): Coroutine_Yield overwrites here.buf with a resume point inside
// its own call frame. If a coroutine that has yielded later completes, is
// deleted and is handed out again by Coroutine_New, the next Chunk_Enter would
// jump to that stale resume point instead of the switch below - confirm whether
// reuse after a yield is meant to be supported.
static void stack_chunk_base(Coroutine *parent){
    Coroutine here;
    here.state = Coroutine_Constructing;
    switch (setjmp(here.buf)) {
    case Chunk_Initial:
        // Direct return from setjmp: this frame has just been created.
        // The parent now has a chunk_of_stack array between its frame and
        // this one, so it can be offered to Coroutine_New via the free list.
        if (parent) {
            assert(parent->state == Coroutine_Constructing);
            parent->state = Coroutine_Free;
            List_AddHead(&g_c.free, &parent->link);
        }
        // This frame is now the deepest chunk: the one asked to create the
        // next chunk when the free list runs dry.
        g_c.tip = &here;

        // Hand control back to whoever requested the chunk
        // (Coroutine_StartSystem or Coroutine_New).
        longjmp(g_c.chunk_allocated, 1);
    case Chunk_Create:
        // Coroutine_New asked the tip to grow the chain by one chunk.
        assert(here.state == Coroutine_Constructing);
        stack_chunk_chunk(&here);
        assert(false);
    case Chunk_Enter:
        // Coroutine_RunNext selected this coroutine: run its start function
        // on this chunk's stack.
        assert(here.state == Coroutine_Running);
        g_c.active = &here;
        here.value = here.start(here.entry_param);
        g_c.active = NULL;
        // The start function has returned: move from runable to inactive.
        assert(here.state == Coroutine_Running);
        List_Remove(&here.link);
        here.state = Coroutine_Complete;
        List_AddTail(&g_c.inactive, &here.link);
        if (g_c.primary == &here) {
            // The primary coroutine finished: unwind to Coroutine_Run.
            longjmp(g_c.controller, Coroutines_CoroutineComplete);
        }
        // Otherwise keep scheduling whatever is runnable next.
        Coroutine_RunNext();
        assert(false);
    }
}

// Bring the coroutine system from Coroutines_Idle to Coroutines_Started.
//
// Resets the singleton pointers and all four lists, creates the semaphore that
// counts runnable coroutines, and primes the stack-chunk chain so that
// g_c.tip refers to a chunk frame that can create further chunks on demand.
void Coroutine_StartSystem()
{
    assert(g_c.state == Coroutines_Idle);
    g_c.state = Coroutines_Starting;

    g_c.tip = NULL;
    g_c.active = NULL;
    g_c.primary = NULL;

    List_Init(&g_c.free);
    List_Init(&g_c.inactive);
    List_Init(&g_c.runable);
    List_Init(&g_c.waiting);

    // pid_t is not guaranteed to be int, so widen it explicitly for printing.
    char tbuf[256];
    snprintf(tbuf, sizeof(tbuf), "/coroutine_waiting_sem_%ld", (long)getpid());
    // Remove any leftover semaphore of the same name (for example from an
    // earlier process that reused this pid) and insist on creating a fresh
    // one: attaching to an existing semaphore would ignore the initial value
    // of 0 and could leave a stale, non-zero count.
    sem_unlink(tbuf);
    g_c.waiting_sem = sem_open(tbuf, O_CREAT | O_EXCL, 0644, 0);
    assert(g_c.waiting_sem != SEM_FAILED);
    // The name was only needed to create the semaphore; unlink it now that
    // creation is known to have succeeded.
    sem_unlink(tbuf);

    // prime the chunk system: Coroutine_PrimeStackChunks does not return
    // normally, it comes back here through longjmp(g_c.chunk_allocated, ...)
    if (!setjmp(g_c.chunk_allocated)){
        Coroutine_PrimeStackChunks();
        assert(false);
    }
    assert(g_c.state == Coroutines_Starting);
    g_c.state = Coroutines_Started;
}

// Take the coroutine system from Coroutines_Started back to Coroutines_Idle.
// The inactive list must already be empty; the runnable-count semaphore is
// closed and its handle cleared.
void Coroutine_StopSystem()
{
    assert(g_c.state == Coroutines_Started);
    g_c.state = Coroutines_Stopping;

    assert(List_IsEmpty(&g_c.inactive));

    // Detach the handle from the global before closing it.
    sem_t *sem = g_c.waiting_sem;
    g_c.waiting_sem = NULL;
    sem_close(sem);

    assert(g_c.state == Coroutines_Stopping);
    g_c.state = Coroutines_Idle;
}

// Run `cor` as the primary coroutine and return once it has completed.
//
// cor:   coroutine to run; it is queued at the front of the runable list.
// value: passed to the coroutine as its entry parameter.
//
// Returns the coroutine's stored value (what Coroutine_GetValue would report).
// The runable and waiting lists must both be empty by the time the primary
// coroutine finishes.
void *Coroutine_Run(Coroutine *cor, void *value){
    assert(g_c.state == Coroutines_Started);
    g_c.primary = cor;
    g_c.state = Coroutines_Active;
    Coroutine_Continue(cor, value, true);

    // First pass: dispatch into the scheduler. Control comes back here through
    // longjmp(g_c.controller, ...) when the primary coroutine completes.
    if (setjmp(g_c.controller) == 0){
        Coroutine_RunNext();
    }

    assert(List_IsEmpty(&g_c.runable));
    assert(List_IsEmpty(&g_c.waiting));
    assert(g_c.state == Coroutines_Active);
    g_c.state = Coroutines_Started;
    return cor->value;
}

// Hand out an idle coroutine, creating a new stack chunk first when the free
// list is empty.
//
// this:     opaque pointer later passed to on_yield.
// on_yield: callback invoked each time the coroutine yields.
// start:    entry function of the coroutine.
//
// Returns the coroutine, in state Coroutine_Idle and placed at the front of
// the inactive list.
Coroutine *Coroutine_New(void *this, Coroutine_YieldCallback on_yield, Coroutine_Start start){
    assert(g_c.state == Coroutines_Started || g_c.state == Coroutines_Active);

    // Nothing free: ask the tip chunk to push a new chunk. That turns the old
    // tip into a free coroutine and returns here via g_c.chunk_allocated.
    if (List_IsEmpty(&g_c.free)){
        if (!setjmp(g_c.chunk_allocated)){
            longjmp(g_c.tip->buf, Chunk_Create);
        }
    }

    List_Link *head = List_GetHead(&g_c.free);
    Coroutine *cor = List_Link_Container(Coroutine, link, head);
    assert(cor->state == Coroutine_Free);

    // Move it from the free list to the inactive list.
    List_Remove(&cor->link);
    List_AddHead(&g_c.inactive, &cor->link);

    // Record the caller's settings and mark it ready to be continued.
    cor->this = this;
    cor->on_yield = on_yield;
    cor->start = start;
    cor->value = NULL;
    cor->state = Coroutine_Idle;

    return cor;
}

// Return an idle or completed coroutine to the back of the free list so that
// Coroutine_New can hand it out again.
void Coroutine_Delete(Coroutine *cor){
    assert(cor->state == Coroutine_Idle || cor->state == Coroutine_Complete);
    List_Remove(&cor->link);
    List_AddTail(&g_c.free, &cor->link);
    cor->state = Coroutine_Free;
}

// Make an idle or waiting coroutine runnable.
//
// cor:   coroutine to schedule.
// value: stored as the coroutine's entry parameter.
// early: true queues it at the front of the runable list, false at the back.
//
// Posts g_c.waiting_sem once so that a Coroutine_RunNext can pick it up.
void Coroutine_Continue(Coroutine *cor, void *value, bool early){
    assert(cor->state == Coroutine_Idle || cor->state == Coroutine_Waiting);

    // Take it off its current list and queue it on the runable list.
    List_Remove(&cor->link);
    if ( !early ) {
        List_AddTail(&g_c.runable, &cor->link);
    } else {
        List_AddHead(&g_c.runable, &cor->link);
    }

    cor->entry_param = value;
    cor->state = Coroutine_Running;
    sem_post(g_c.waiting_sem);
}

// Suspend the currently active coroutine and let the scheduler run another.
//
// value: stored as the coroutine's value, where Coroutine_GetValue can read it.
//
// Returns the entry parameter supplied by the Coroutine_Continue call that
// made this coroutine runnable again.
void *Coroutine_Yield(void *value){
    Coroutine *me = g_c.active;
    assert(me && me->state == Coroutine_Running);
    me->value = value;
    // Move from the runable list to the waiting list.
    me->state = Coroutine_Waiting;
    List_Remove(&me->link);
    List_AddTail(&g_c.waiting, &me->link);
    // Save the resume point in me->buf. On the direct return from setjmp,
    // notify the owner and switch away; Coroutine_RunNext does not return.
    // g_c.active is not cleared here - it keeps pointing at this coroutine
    // until some other code assigns it.
    if (!setjmp(me->buf)){
        me->on_yield(me->this);
        Coroutine_RunNext();
    }
    // Reached via longjmp(me->buf, Chunk_Enter) from Coroutine_RunNext.
    g_c.active = me;
    // when we return here - we are running again
    assert(me->state == Coroutine_Running);
    return me->entry_param;
}

// Report the value most recently stored for `cor`: either what it last
// yielded or, once it has completed, what its start function returned.
void *Coroutine_GetValue(Coroutine *cor){
    void *stored = cor->value;
    return stored;
}

// Return the coroutine recorded as active in g_c, or NULL when none is.
Coroutine *Coroutine_GetActive()
{
    Coroutine *current = g_c.active;
    return current;
}

// True while `cor` has been started and has not yet completed, i.e. it is in
// state Coroutine_Running or Coroutine_Waiting.
bool Coroutine_IsRunning(Coroutine *cor)
{
    switch (cor->state) {
    case Coroutine_Running:
    case Coroutine_Waiting:
        return true;
    default:
        return false;
    }
}
