#include <stdio.h>
#include <stdint.h>
#include <string>
#include <algorithm>
#include <map>
#include <vector>

#include <sstream>
#include <iostream>
#include <iterator>

#include <boost/pointer_cast.hpp>

#include <boost/shared_ptr.hpp>

// add template subclass of istream for sstring

// Abstract base class of every bencoded value handled by this file.
// It also hosts the decimal integer reader/writer shared by the concrete types.
struct item {
    typedef boost::shared_ptr<item> ptr;

    // Type tags returned by type().
    // NOTE(review): T_UNDEF, T_TINY and T_FLOAT are not returned by any type visible here.
    enum { T_UNDEF, T_DICT, T_LIST, T_INT, T_STR, T_TINY, T_FLOAT };

    // Values are created as derived objects but handled through item pointers,
    // so destruction through the base must be well defined.
    virtual ~item() { }

    // Writes `value` to `os` as plain decimal digits with a leading '-' for
    // negative numbers (no 'i'/'e' framing).
    static void bencode_integer(std::ostream& os, int64_t value)
    {
        // Work on the unsigned magnitude: negating the most negative int64_t
        // as a signed value would overflow.
        uint64_t magnitude= uint64_t(value);
        if (value<0) {
            os << '-';
            magnitude= uint64_t(0) - magnitude;
        }
        // Digits are produced least significant first, then emitted reversed.
        std::string num;
        do
        {
            num += char('0' + int(magnitude % 10));
            magnitude /= 10;
        } while (magnitude);
        std::copy(num.rbegin(), num.rend(), std::ostream_iterator<char>(os));
    }

    // Reads an optionally '-'-prefixed run of decimal digits from `is` and
    // returns its value. The first character that is not part of the number is
    // put back with unget(). An input without any digit and without a sign
    // yields 0.
    // Throws (const char*) "invalid enc#10" for a '-' not followed by a digit
    // and "invalid enc#11" when the number does not fit in int64_t.
    static int64_t bdecode_integer(std::istream& is)
    {
        int c= is.get();
        const bool negative= (c=='-');
        if (negative)
            c= is.get();    // step past the sign so the digit loop sees the first digit

        // Largest magnitude that still fits: 2^63 for negatives, 2^63-1 otherwise.
        const uint64_t limit= negative ? (uint64_t(1)<<63) : (uint64_t(1)<<63)-1;
        uint64_t magnitude= 0;
        bool any_digit= false;
        while (c>='0' && c<='9')
        {
            const uint64_t digit= uint64_t(c-'0');
            if (magnitude > (limit-digit)/10)
                throw "invalid enc#11";
            magnitude= magnitude*10 + digit;
            any_digit= true;
            c= is.get();
        }
        is.unget();

        if (negative && !any_digit)
            throw "invalid enc#10";
        if (!negative)
            return int64_t(magnitude);
        if (magnitude==0)
            return 0;
        // -(m-1)-1 equals -m without ever forming +2^63 as a signed value.
        return -int64_t(magnitude-1) - 1;
    }

    // Writes the complete encoding of this value to `os`.
    virtual void bencode(std::ostream& os) const= 0;

    // Decodes one value of whatever kind comes next in `is`.
    static item::ptr bdecode(std::istream& is);

    // Returns the T_* tag of the concrete type.
    virtual int type()=0;
};
// Container and scalar types wrapped by the concrete item subclasses.

// Ordered sequence of values (payload of `list`).
typedef std::vector<item::ptr> item_vector;
// Key -> value mapping (payload of `dict`). Key order and uniqueness are decided
// by std::map's default comparator, std::less<item::ptr>.
// NOTE(review): presumably this is meant to resolve to the operator<(item::ptr, item::ptr)
// defined later in this file rather than a pointer-based comparison — confirm.
typedef std::map<item::ptr,item::ptr> item_dict;
// Raw byte string (payload of `bstring`).
typedef std::string item_string;
// 64-bit signed integer (payload of `integer`).
typedef int64_t item_integer;

// Byte-string value, encoded as "<decimal length>:<raw bytes>".
struct bstring : item {
    item_string _s;

    // Writes the length, a ':' and then the raw bytes to `os`.
    virtual void bencode(std::ostream& os) const
    {
        bencode_integer(os, _s.size());
        os << ':';
        std::copy(_s.begin(), _s.end(), std::ostream_iterator<char>(os));
    }

    // Decodes "<length>:<bytes>" from `is` and returns it as a new bstring.
    // Throws (const char*) "invalid enc#1" when the ':' separator is missing and
    // "invalid enc#2" when the length is negative or the stream ends before
    // `length` bytes were read.
    static item::ptr bdecode(std::istream& is)
    {
        int64_t length= bdecode_integer(is);
        if (is.get() != ':')
            throw "invalid enc#1";
        if (length < 0)
            throw "invalid enc#2";

        // The length comes straight from the input. Read in bounded chunks so a
        // bogus huge length fails with the usual decode error after consuming
        // what is actually there, instead of attempting one giant allocation
        // whose failure would surface as a different exception type.
        const int64_t chunk= 0x10000;
        item_string s;
        int64_t remaining= length;
        while (remaining > 0)
        {
            const size_t want= size_t(std::min(remaining, chunk));
            const size_t have= s.size();
            s.resize(have + want);
            is.read(&s[have], std::streamsize(want));
            if (size_t(is.gcount()) != want)
                throw "invalid enc#2";
            remaining -= int64_t(want);
        }
        return item::ptr(new bstring(s));
    }

    // Wraps a copy of `s`.
    explicit bstring(const item_string& s) : _s(s) { }

    // Mutable access to the stored bytes.
    item_string& get() { return _s; }
    virtual int type() { return T_STR; }
};

struct dict : item {
    item_dict _d;

    virtual void bencode(std::ostream& os) const
    {
        os << 'd';
        for (item_dict::const_iterator v= _d.begin() ; v!=_d.end() ; v++)
        {
            (*v).first->bencode(os);
            (*v).second->bencode(os);
        }
        os << 'e';
    }

    static item::ptr bdecode(std::istream& is)
    {
        item_dict d;
        if (is.get() != 'd')
            throw "invalid enc#3";
        while (is.peek()!='e') {
            item::ptr key= item::bdecode(is);
            d.insert(item_dict::value_type(key, item::bdecode(is)));
        }
        if (is.get() != 'e')
            throw "invalid enc#4";
        return item::ptr(new dict(d));
    }

    explicit dict(const item_dict& d) : _d(d) { }

item_dict& get() { return _d; }
    virtual int type() { return T_DICT; }
};

struct list : item {
    item_vector _l;

    virtual void bencode(std::ostream& os) const
    {
        os << 'l';
        for (item_vector::const_iterator v= _l.begin() ; v!=_l.end() ; v++)
        {
            (*v)->bencode(os);
        }
        os << 'e';
    }

    static item::ptr bdecode(std::istream& is)
    {
        item_vector l;
        if (is.get() != 'l')
            throw "invalid enc#5";
        while (is.peek()!='e') {
            l.push_back(item::bdecode(is));
        }
        if (is.get() != 'e')
            throw "invalid enc#6";
        return item::ptr(new list(l));
    }

    explicit list(const item_vector& l) : _l(l) { }

item_vector& get() { return _l; }
    virtual int type() { return T_LIST; }
};

// Integer value, encoded as 'i' <decimal digits> 'e'.
struct integer : item {
    item_integer _i;

    // Wraps the given number.
    explicit integer(item_integer i) : _i(i) { }

    // Mutable access to the stored number.
    item_integer& get() { return _i; }

    virtual int type() { return T_INT; }

    // Writes 'i', the decimal digits of the number, then 'e'.
    virtual void bencode(std::ostream& os) const
    {
        os << 'i';
        bencode_integer(os, _i);
        os << 'e';
    }

    // Decodes an integer from `is`.
    // Throws (const char*) "invalid enc#7" / "invalid enc#8" when the opening
    // 'i' or closing 'e' is missing.
    static item::ptr bdecode(std::istream& is)
    {
        const bool opened= (is.get() == 'i');
        if (!opened)
            throw "invalid enc#7";

        const item_integer parsed= bdecode_integer(is);

        if (is.get() == 'e')
            return item::ptr(new integer(parsed));
        throw "invalid enc#8";
    }
};

// Decodes the next value from `is`, choosing the concrete decoder from the
// pending character: 'i' integer, 'd' dictionary, 'l' list, a decimal digit
// byte string. Nothing is consumed here; the chosen decoder reads its own
// leading character.
// Throws (const char*) "invalid enc#9" for any other pending character.
item::ptr item::bdecode(std::istream& is)
{
    const int lead= is.peek();
    if (lead=='i')
        return integer::bdecode(is);
    if (lead=='d')
        return dict::bdecode(is);
    if (lead=='l')
        return list::bdecode(is);

    // Look at the pending character again to decide between string and error.
    const int pending= is.peek();
    if (pending<'0')
        throw "invalid enc#9";
    if (is.peek()>'9')
        throw "invalid enc#9";
    return bstring::bdecode(is);
}

// Ordering for item pointers: values of different types order by their type
// tag; two integers or two byte strings order by their stored value; any other
// pair of equal type compares as "not less".
bool operator <(item::ptr a,item::ptr b)
{
    const int kind_a= a->type();
    const int kind_b= b->type();
    if (kind_a != kind_b)
        return kind_a < kind_b;

    if (kind_a == item::T_INT) {
        item_integer& lhs= boost::dynamic_pointer_cast<integer>(a)->get();
        item_integer& rhs= boost::dynamic_pointer_cast<integer>(b)->get();
        return lhs < rhs;
    }
    if (kind_a == item::T_STR) {
        item_string& lhs= boost::dynamic_pointer_cast<bstring>(a)->get();
        item_string& rhs= boost::dynamic_pointer_cast<bstring>(b)->get();
        return lhs < rhs;
    }
    return false;
}

// Returns the node in `tree` addressed by `path`, which has the form
//   <sub>.<sub>.<sub>
// Each <sub> is interpreted according to the node reached so far: a list index
// (parsed by strtoul with base 0) when that node is a list, or a byte-string
// key when it is a dictionary. An empty path returns `tree` itself.
//
// Throws (const char*):
//   "path: no dict or list"    - a component is applied to a null, integer or string node
//   "path: expected integer"   - a list component is not a complete number
//   "path: index out of range" - a list index is past the end of the list
//   "path: key not found"      - a dictionary has no entry for the component
item::ptr getnode(item::ptr tree, const std::string& path)
{
    if (path.empty())
        return tree;

    size_t off= 0;
    for (;;)
    {
        if (!tree)
            throw "path: no dict or list";

        const size_t dot= path.find('.', off);
        // substr takes a count, not an end position.
        const std::string sub= path.substr(off, dot==std::string::npos ? std::string::npos : dot-off);

        if (tree->type()==item::T_LIST) {
            const char *start= sub.c_str();
            char *end= 0;
            const unsigned long idx= strtoul(start, &end, 0);
            if (end==start || *end!='\0') {
                throw "path: expected integer";
            }
            item_vector& elements= boost::dynamic_pointer_cast<list>(tree)->get();
            if (idx >= elements.size())
                throw "path: index out of range";
            tree= elements[idx];
        }
        else if (tree->type()==item::T_DICT) {
            // find() instead of operator[]: a lookup must not insert an empty
            // entry into the tree when the key is absent.
            // NOTE(review): the lookup relies on item_dict comparing keys by value — confirm.
            item_dict& entries= boost::dynamic_pointer_cast<dict>(tree)->get();
            const item_dict::iterator hit= entries.find(item::ptr(new bstring(sub)));
            if (hit==entries.end())
                throw "path: key not found";
            tree= hit->second;
        }
        else {
            throw "path: no dict or list";
        }

        if (dot==std::string::npos)
            break;
        off= dot+1;
    }
    return tree;
}

#include "benctests.h"

// Reads `is` until get() reports -1 and prints every byte to stdout as a
// space-prefixed two-digit hex number, followed by a newline.
void hexdump(std::istream& is)
{
    for (int byte= is.get() ; byte!=-1 ; byte= is.get())
        printf(" %02x", uint8_t(byte)&0xff);
    printf("\n");
}
// Round-trip check for one encoded buffer: decodes `str`, reports when the
// decoder stopped before the end of the input, re-encodes the decoded tree and
// prints "ok" when the result equals `str`; otherwise prints hex dumps of both
// versions followed by "error". Decoder failures thrown as const char* are
// printed as "ERROR: <message>".
void test(const std::string& str)
{
    try {
        std::stringstream input(str);
        item::ptr tree= item::bdecode(input);

        const int consumed= (int)input.tellg();
        const int total= (int)str.size();
        if (consumed != total) {
            printf("not processed entirely: pos=%x, end=%x\n", consumed, total);
            printf("nextchar: %02x\n", input.get());
        }

        std::stringstream output;
        tree->bencode(output);
        const std::string reencoded= output.str();

        if (reencoded == str) {
            printf("ok\n");
        }
        else {
            // Rewind the input so the dump starts at the first byte.
            input.seekg(0);
            printf("orig: (%d)", total);
            hexdump(input);
            printf("new : (%d)", (int)reencoded.size());
            hexdump(output);
            printf("error\n");
        }

//    item::ptr a= getnode(x, "info.length");
//    printf("info.length:");
//    x->bencode(std::cout);
//    printf("\n");
    }
    catch (const char*msg)
    {
        printf("ERROR: %s\n", msg);
    }
}
// Returns the complete contents of the file `name`, read in binary mode.
// When the file cannot be opened the error is reported with perror() and an
// empty string is returned.
std::string readfile(const char*name)
{
    FILE *f= fopen(name, "rb");
    if (!f) {
        perror(name);
        return "";
    }

    const size_t chunk= 0x10000;
    std::string data;
    for (;;)
    {
        if (feof(f))
            break;
        // Grow by one chunk, read into the new tail, then trim to what arrived.
        const size_t used= data.size();
        data.resize(used+chunk);
        const size_t n= fread(&data[used], 1, chunk, f);
        data.resize(used+n);
        if (n==0)
            break;
    }
    fclose(f);
    return data;
}
// With command-line arguments: runs the round-trip test on the contents of
// every named file. Without arguments: runs it on each of the NBTESTS entries
// of btests, printing a numbered separator before each.
int main(int argc, char**argv)
{
    if (argc<=1) {
        unsigned i= 0;
        while (i<NBTESTS) {
            printf("---- %4d\n", i);
            test(btests[i]);
            ++i;
        }
        return 0;
    }

    for (char **arg= argv+1 ; arg != argv+argc ; ++arg)
        test(readfile(*arg));
    return 0;
}
