#include <ctype.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <exception>
#include <iostream>
#include <iterator>
#include <map>
#include <sstream>
#include <string>
#include <vector>

#include <boost/pointer_cast.hpp>
#include <boost/shared_ptr.hpp>

// todo: improve operator<

// add template subclass of istream for sstring
//
// Global flag, initialised to true.  It is not referenced by any code
// visible in this file.
bool g_morehex= true;

// Forward declarations of the encoder interface and the concrete node
// types.  They use the `struct` class-key because that is how the types
// are defined further down; mixing `class` and `struct` for the same
// type triggers mismatched-tag warnings on several compilers.
struct encoder;

struct dict;
struct list;
struct integer;
struct bstring;
// Abstract base of every node in a decoded tree.
//
// A node reports its kind through type() (one of the T_xxx constants) and
// knows how to hand its payload to an encoder through encode().
// The asxxx() helpers downcast to the concrete node type and throw a
// `const char*` message when the node is of a different kind.
struct item {
    typedef boost::shared_ptr<item> ptr;
    enum { T_UNDEF, T_DICT, T_LIST, T_INT, T_STR, T_TINY, T_FLOAT };

    // Polymorphic base: make destruction through an item* well defined.
    virtual ~item() { }

    // Write this node to `os` using the format implemented by `e`.
    virtual void encode(std::ostream& os, encoder* e) const=0;
    // Returns the T_xxx constant identifying the concrete node type.
    virtual int type() const=0;

    // Checked downcasts.  The concrete types are still incomplete at this
    // point, so the cast goes through boost's raw-pointer
    // dynamic_pointer_cast template instead of a direct dynamic_cast.
    dict *asdict() { if (type()!=T_DICT) throw "not a dict"; return boost::dynamic_pointer_cast<dict>(this); }
    list *aslist() { if (type()!=T_LIST) throw "not a list"; return boost::dynamic_pointer_cast<list>(this); }
    integer *asinteger() { if (type()!=T_INT) throw "not a integer"; return boost::dynamic_pointer_cast<integer>(this); }
    bstring *asstring() { if (type()!=T_STR) throw "not a string"; return boost::dynamic_pointer_cast<bstring>(this); }
/*
    int threewaycompare(const item& b) const
    {
        if (type()<b.type())
            return -1;
        if (type()>b.type())
            return 1;
        return 0;
    }
*/
};

// Payload types carried by the concrete node classes.
typedef std::vector<item::ptr> item_vector;    // ordered sequence of shared item pointers
typedef std::map<item::ptr,item::ptr> item_dict;   // map whose keys and values are both shared item pointers
typedef std::string item_string;               // string payload; may hold arbitrary bytes
typedef int64_t item_integer;                  // signed 64 bit integer payload

bool isbinary(const item_string& s)
{
    int x[2]= {0,0};
    for (item_string::const_iterator p= s.begin() ; p!=s.end() ; p++)
    {
        char c= *p;
        if (c==9 || c==10 || c==13 || (c>=32 && c<=126))
            x[0]++;
        else
            x[1]++;
    }
    return x[1]!=0;
    // 0.5 < x[0]/x[1] < 1.5
    //return (x[1]<2*x[0]) && (2*x[0]<3*x[1]);
}
// Interface implemented by every serialisation format.
//
// The encode() overloads write one payload of the matching kind to `os`;
// the overload without a payload writes the 'undefined' value.
// decode() reads one complete item from `is` and returns it.
struct encoder {
    // Polymorphic interface: allow destruction through an encoder*.
    virtual ~encoder() { }

    virtual void encode(std::ostream& os, const item_string& s)=0;
    virtual void encode(std::ostream& os, const item_dict& d)=0;
    virtual void encode(std::ostream& os, const item_vector& l)=0;
    virtual void encode(std::ostream& os, const item_integer& i)=0;
    virtual void encode(std::ostream& os)=0;

    virtual item::ptr decode(std::istream& is)=0;
};

/*
template<typename T> 
class threewaycomparable {
    virtual T& get()=0;
    int threewaycompare(const T& b) const
    {
        int c= item::threewaycompare(b);
        if (c)
            return c;
        if (get()<b.get())
            return -1;
        if (get()>b.get())
            return 1;
        return 0;
    }
};
*/

// Tree node holding a string payload.
struct bstring : item  /* , threewaycomparable<item_string> */ {
    item_string _s;     // the stored string

    // Stores a copy of `s`.
    explicit bstring(const item_string& s)
        : _s(s)
    {
    }

    // Always T_STR.
    virtual int type() const
    {
        return T_STR;
    }

    // Mutable access to the stored string.
    item_string& get()
    {
        return _s;
    }

    // Forwards the stored string to the encoder's string overload.
    virtual void encode(std::ostream& os, encoder* e) const
    {
        e->encode(os, _s);
    }
};

// Tree node holding a dictionary of item -> item.
struct dict : item  /* , threewaycomparable<item_dict> */ {
    item_dict _d;       // the stored mapping

    // Creates an empty dictionary node.
    dict()
    {
    }

    // Stores a copy of `d`.
    explicit dict(const item_dict& d)
        : _d(d)
    {
    }

    // Always T_DICT.
    virtual int type() const
    {
        return T_DICT;
    }

    // Mutable access to the stored mapping.
    item_dict& get()
    {
        return _d;
    }

    // Forwards the stored mapping to the encoder's dictionary overload.
    virtual void encode(std::ostream& os, encoder* e) const
    {
        e->encode(os, _d);
    }
};

// Tree node holding an ordered sequence of items.
struct list : item  /* , threewaycomparable<item_vector> */ {
    item_vector _l;     // the stored sequence

    // Creates an empty list node.
    list()
    {
    }

    // Stores a copy of `l`.
    explicit list(const item_vector& l)
        : _l(l)
    {
    }

    // Always T_LIST.
    virtual int type() const
    {
        return T_LIST;
    }

    // Mutable access to the stored sequence.
    item_vector& get()
    {
        return _l;
    }

    // Forwards the stored sequence to the encoder's list overload.
    virtual void encode(std::ostream& os, encoder* e) const
    {
        e->encode(os, _l);
    }
};

// Tree node holding a signed 64 bit integer.
struct integer : item  /* , threewaycomparable<item_integer> */ {
    item_integer _i;    // the stored value

    // Stores `i`.
    explicit integer(item_integer i)
        : _i(i)
    {
    }

    // Always T_INT.
    virtual int type() const
    {
        return T_INT;
    }

    // Mutable access to the stored value.
    item_integer& get()
    {
        return _i;
    }

    // Forwards the stored value to the encoder's integer overload.
    virtual void encode(std::ostream& os, encoder* e) const
    {
        e->encode(os, _i);
    }
};
// Tree node representing the 'undefined' value; it carries no payload.
struct undef : item {

    explicit undef()
    {
    }

    // Always T_UNDEF.
    virtual int type() const
    {
        return T_UNDEF;
    }

    // Present for symmetry with the other node types; there is nothing to return.
    void get()
    {
    }

    // Calls the encoder's payload-less overload.
    virtual void encode(std::ostream& os, encoder* e) const
    {
        e->encode(os);
    }
};

// Ordering on shared item pointers.
//
//  - a null pointer sorts before every non-null pointer
//  - items of different kinds are ordered by their type() value
//  - integers and strings are ordered by value
//  - lists are ordered lexicographically, element by element
//  - dicts are ordered lexicographically over their (key, value) entries
//  - all other kinds (e.g. undef) compare equivalent to each other
//
// Previously every list compared equivalent to every other list (and the
// same for dicts), so containers used as dictionary keys collapsed into a
// single entry, and a null pointer was dereferenced.
bool operator<(item::ptr a,item::ptr b)
{
    if (a.get()==0 || b.get()==0)
        return a.get()==0 && b.get()!=0;

    if (a->type() < b->type())
        return true;
    if (a->type() > b->type())
        return false;
    switch(a->type()) {
        case item::T_INT: return a->asinteger()->get() < b->asinteger()->get();
        case item::T_STR: return a->asstring()->get() < b->asstring()->get();
        case item::T_LIST: {
            const item_vector& la= a->aslist()->get();
            const item_vector& lb= b->aslist()->get();
            item_vector::const_iterator ia= la.begin();
            item_vector::const_iterator ib= lb.begin();
            for ( ; ia!=la.end() && ib!=lb.end() ; ++ia, ++ib) {
                if (*ia < *ib)
                    return true;
                if (*ib < *ia)
                    return false;
            }
            // equal common prefix: the shorter list sorts first
            return ia==la.end() && ib!=lb.end();
        }
        case item::T_DICT: {
            const item_dict& da= a->asdict()->get();
            const item_dict& db= b->asdict()->get();
            item_dict::const_iterator ia= da.begin();
            item_dict::const_iterator ib= db.begin();
            for ( ; ia!=da.end() && ib!=db.end() ; ++ia, ++ib) {
                if (ia->first < ib->first)
                    return true;
                if (ib->first < ia->first)
                    return false;
                if (ia->second < ib->second)
                    return true;
                if (ib->second < ia->second)
                    return false;
            }
            // equal common prefix: the smaller dict sorts first
            return ia==da.end() && ib!=db.end();
        }
    }
    return false;
}

// get the node in the tree as specified by the path
// <sub>.<sub>.<sub>
// where <sub> is a number when the current node is a list, and a string
// key when the current node is a dict.
//
// Throws a `const char*` message when
//  - a node on the path is missing (null),
//  - a list component does not start with a number or is out of range,
//  - a dict component is not present as a string key,
//  - the path continues into a node that is neither a list nor a dict.
//
// The tree is never modified: dict lookups use find() rather than
// operator[], which used to insert a null entry for every missing key.
item::ptr getnode(item::ptr tree, const std::string& path)
{
    size_t off= 0;

    while (1)
    {
        if (tree.get()==0)
            throw "path: node not found";

        size_t dot= path.find('.', off);

        // the current component is [off, dot);  substr takes a length,
        // not an end position
        const std::string sub= (dot==std::string::npos) ? path.substr(off)
                                                        : path.substr(off, dot-off);

        if (tree->type()==item::T_LIST) {
            char *p=0;
            unsigned long idx= strtoul(sub.c_str(), &p, 0);
            if (p==sub.c_str())
                throw "path: expected integer";
            item_vector& v= tree->aslist()->get();
            if (idx>=v.size())
                throw "path: index out of range";
            tree= v[idx];
        }
        else if (tree->type()==item::T_DICT) {
            item_dict& d= tree->asdict()->get();
            item_dict::iterator found= d.find(item::ptr(new bstring(sub)));
            if (found==d.end())
                throw "path: key not found";
            tree= found->second;
        }
        else {
            throw "path: no dict or list";
        }

        if (dot==std::string::npos)
            break;
        off= dot+1;
    }
    return tree;
}

// Encoder / decoder for the compact binary 'genc' format.
//
// Every item starts with a variable length header:
//   first byte : bit 7 = more size bytes follow, bits 4-6 = type,
//                bits 0-3 = lowest 4 bits of the size
//   next bytes : bit 7 = more size bytes follow, bits 0-6 = next 7 size bits
//
// The meaning of 'size' depends on the type:
//   T_STR   : number of bytes following the header
//   T_DICT  : number of key/value pairs following the header
//   T_LIST  : number of items following the header
//   T_TINY  : the integer itself, stored as 2*|value| + (value<0 ? 1 : 0)
//   T_INT   : 2*nbytes + (value<0 ? 1 : 0), followed by nbytes of the
//             magnitude, least significant byte first
//   T_UNDEF : zero
//
// All decode errors are reported by throwing a `const char*` message.
struct gsmkencoder : encoder {
    enum { T_UNDEF, T_DICT, T_LIST, T_INT, T_STR, T_TINY, T_FLOAT };

    // Writes the header for an item of the given type and size.
    static void encode_header(std::ostream& os, int type, int64_t size)
    {
        os << (char)((size>=16?0x80:0) | (type<<4) | (size&15));
        size>>=4;
        while (size!=0 && size!=-1) {
            os << (char)(((-127>size || size>127)?0x80:0) | (size&127));
            size>>=7;
        }
    }

    // Reads a header, returning its type and size.
    // End of input is detected from the result of get(): eof() only
    // becomes true after a read has already failed.
    static void decode_header(std::istream& is, int &type, int64_t& size)
    {
        int byte= is.get();
        if (byte==EOF)
            throw "invalid genc#0";
        uint64_t value= byte&15;
        type= (byte>>4)&7;
        int shift=4;
        while (byte&0x80) {
            byte= is.get();
            if (byte==EOF)
                throw "invalid genc#1";
            // more continuation bytes than fit in 64 bits: malformed input
            if (shift>=64)
                throw "invalid genc#1";
            // shift in 64 bits; an int shift overflows for large sizes
            value |= uint64_t(byte&127)<<shift;
            shift+=7;
        }
        size= int64_t(value);
    }

    virtual void encode(std::ostream& os, const item_string& s)
    {
        encode_header(os, T_STR, s.size());
        std::copy(s.begin(), s.end(), std::ostream_iterator<char>(os));
    }

    // Reads `length` bytes of string data.
    // The data is read in bounded chunks so that a bogus length in the
    // header fails with a decode error instead of a huge allocation.
    item::ptr decodestring(std::istream& is, int64_t length)
    {
        if (length<0)
            throw "invalid genc#2";
        item_string s;
        char buf[4096];
        while (length>0) {
            const std::streamsize want= length<(int64_t)sizeof(buf) ? (std::streamsize)length
                                                                    : (std::streamsize)sizeof(buf);
            is.read(buf, want);
            if (is.gcount()!=want)
                throw "invalid genc#2";
            s.append(buf, (size_t)want);
            length-= want;
        }
        return item::ptr(new bstring(s));
    }

    virtual void encode(std::ostream& os, const item_dict& d)
    {
        encode_header(os, T_DICT, d.size());
        for (item_dict::const_iterator v= d.begin() ; v!=d.end() ; v++)
        {
            (*v).first->encode(os, this);
            (*v).second->encode(os, this);
        }
    }

    // Reads `size` key/value pairs.
    item::ptr decodedict(std::istream& is, int64_t size)
    {
        if (size<0)
            throw "invalid genc#3";
        item_dict d;
        while (size--) {
            item::ptr key= decode(is);
            d.insert(item_dict::value_type(key, decode(is)));
        }
        if (is.eof())
            throw "invalid genc#3";
        return item::ptr(new dict(d));
    }

    virtual void encode(std::ostream& os, const item_vector& l)
    {
        encode_header(os, T_LIST, l.size());
        for (item_vector::const_iterator v= l.begin() ; v!=l.end() ; v++)
        {
            (*v)->encode(os, this);
        }
    }

    // Reads `size` list elements.
    item::ptr decodelist(std::istream& is, int64_t size)
    {
        item_vector l;
        if (is.eof() || size<0)
            throw "invalid genc#4";
        while (size--) {
            l.push_back(decode(is));
        }
        if (is.eof())
            throw "invalid genc#5";
        return item::ptr(new list(l));
    }

    virtual void encode(std::ostream& os, const item_integer& i)
    {
        if (-0x0fffffff<=i && i<=0x0fffffff) {
            encode_header(os, T_TINY, i<0?-2*i+1 : 2*i);
        }
        else {
            // take the magnitude in unsigned arithmetic: negating the
            // most negative int64_t as a signed value is undefined
            uint64_t val= uint64_t(i);
            if (i<0)
                val= 0-val;
            std::vector<uint8_t> v;
            while (val!=0) {
                v.push_back(val&255);
                val>>=8;
            }
            encode_header(os, T_INT, v.size()*2+(i<0?1:0));
            std::copy(v.begin(), v.end(), std::ostream_iterator<char>(os));
        }
    }

    // Decodes an integer from the header size field (tiny) or from the
    // magnitude bytes following the header.
    item::ptr decodeinteger(std::istream& is, int64_t size, bool tiny)
    {
        if (size<0)
            throw "invalid genc#6";
        uint64_t magnitude= 0;
        if (tiny) {
            magnitude= uint64_t(size/2);
        }
        else {
            const int64_t nbytes= size/2;
            // more than 8 bytes cannot be represented in 64 bits
            if (nbytes>8)
                throw "invalid genc#6";
            for (int byte= 0 ; byte<nbytes ; byte++) {
                const int c= is.get();
                if (c==EOF)
                    throw "invalid genc#6";
                magnitude |= uint64_t(c&0xff)<<(byte*8);
            }
        }
        // the lowest bit of the size field is the sign
        const item_integer i= (size&1) ? item_integer(0-magnitude) : item_integer(magnitude);
        return item::ptr(new integer(i));
    }

    virtual void encode(std::ostream& os)
    {
        encode_header(os, T_UNDEF, 0);
    }
    // The undefined value has no data after its header.
    item::ptr decodeundef(std::istream& /*is*/)
    {
        return item::ptr(new undef());
    }

    // Reads one complete item, recursing into lists and dicts.
    virtual item::ptr decode(std::istream& is)
    {
        int64_t size;
        int type;
        if (is.eof())
            throw "empty";
        decode_header(is, type, size);
        switch(type)
        {
            case item::T_INT:
            case item::T_TINY:
                return decodeinteger(is, size, type==item::T_TINY);
            case item::T_DICT: return decodedict(is, size);
            case item::T_LIST: return decodelist(is, size);
            case item::T_STR: return decodestring(is, size);
            case item::T_UNDEF: return decodeundef(is);
            default:
                  // int64_t is not `long long` on every platform
                  printf("unknown header: @%x %d %lld\n", (int)is.tellg(), type, (long long)size);
                  throw "invalid genc#7";
        }
    }
};

// Encoder / decoder for the textual bencode format:
//   integer : i<decimal>e
//   string  : <length>:<bytes>
//   list    : l<items>e
//   dict    : d<key><value>...e
// The undefined value cannot be represented.
// All errors are reported by throwing a `const char*` message.
struct bencoder : encoder {

    // Writes `value` in decimal, with a leading '-' when negative.
    static void encode_integer(std::ostream& os, int64_t value)
    {
        if (value==0) {
            os << '0';
            return;
        }
        // work on the magnitude in unsigned arithmetic: negating the
        // most negative int64_t as a signed value is undefined
        uint64_t magnitude= uint64_t(value);
        if (value<0) {
            os << '-';
            magnitude= 0-magnitude;
        }
        item_string num;
        while (magnitude)
        {
            num += char('0'+(magnitude % 10));
            magnitude /= 10;
        }
        std::copy(num.rbegin(), num.rend(), std::ostream_iterator<char>(os));
    }

    // Reads an optional '-' followed by decimal digits, leaving the first
    // non-digit character in the stream.  Returns 0 when there are no digits.
    // Throws when the number does not fit in an int64_t.
    static int64_t decode_integer(std::istream& is)
    {
        bool negative=false;
        // keep the result of get() in an int so that end of input can be
        // told apart from a 0xff byte
        int c= is.get();
        if (c=='-') {
            negative=true;
            c= is.get();
        }
        // largest magnitude that is still representable
        const uint64_t limit= (uint64_t(1)<<63) - (negative?0:1);
        uint64_t magnitude=0;
        while (c>='0' && c<='9')
        {
            const uint64_t digit= c-'0';
            if (magnitude > (limit-digit)/10)
                throw "invalid benc#0";
            magnitude= magnitude*10 + digit;
            c= is.get();
        }
        if (c!=EOF)
            is.unget();

        return negative ? int64_t(0-magnitude) : int64_t(magnitude);
    }

    virtual void encode(std::ostream& os, const item_string& s)
    {
        encode_integer(os, s.size());
        os << ':';
        std::copy(s.begin(), s.end(), std::ostream_iterator<char>(os));
    }

    // Reads <length>:<bytes>.
    // The data is read in bounded chunks so that a bogus length fails
    // with a decode error instead of a huge allocation.
    item::ptr decodestring(std::istream& is)
    {
        int64_t length= decode_integer(is);
        if (is.get() != ':')
            throw "invalid benc#1";
        if (length<0)
            throw "invalid benc#2";
        item_string s;
        char buf[4096];
        while (length>0) {
            const std::streamsize want= length<(int64_t)sizeof(buf) ? (std::streamsize)length
                                                                    : (std::streamsize)sizeof(buf);
            is.read(buf, want);
            if (is.gcount()!=want)
                throw "invalid benc#2";
            s.append(buf, (size_t)want);
            length-= want;
        }
        return item::ptr(new bstring(s));
    }
    virtual void encode(std::ostream& os, const item_dict& d)
    {
        os << 'd';
        for (item_dict::const_iterator v= d.begin() ; v!=d.end() ; v++)
        {
            (*v).first->encode(os, this);
            (*v).second->encode(os, this);
        }
        os << 'e';
    }

    // Reads d<key><value>...e ; hitting end of input inside the dict makes
    // the nested decode() throw.
    item::ptr decodedict(std::istream& is)
    {
        item_dict d;
        if (is.get() != 'd')
            throw "invalid benc#3";
        while (is.peek()!='e') {
            item::ptr key= decode(is);
            d.insert(item_dict::value_type(key, decode(is)));
        }
        if (is.get() != 'e')
            throw "invalid benc#4";
        return item::ptr(new dict(d));
    }
    virtual void encode(std::ostream& os, const item_vector& l)
    {
        os << 'l';
        for (item_vector::const_iterator v= l.begin() ; v!=l.end() ; v++)
        {
            (*v)->encode(os, this);
        }
        os << 'e';
    }

    // Reads l<items>e.
    item::ptr decodelist(std::istream& is)
    {
        item_vector l;
        if (is.get() != 'l')
            throw "invalid benc#5";
        while (is.peek()!='e') {
            l.push_back(decode(is));
        }
        if (is.get() != 'e')
            throw "invalid benc#6";
        return item::ptr(new list(l));
    }
    virtual void encode(std::ostream& os, const item_integer& i)
    {
        os << 'i';
        encode_integer(os, i);
        os << 'e';
    }


    // Reads i<decimal>e.
    item::ptr decodeinteger(std::istream& is)
    {
        if (is.get() != 'i')
            throw "invalid benc#7";
        item_integer i= decode_integer(is);
        if (is.get() != 'e')
            throw "invalid benc#8";
        return item::ptr(new integer(i));
    }
    // bencode has no representation for the undefined value.
    virtual void encode(std::ostream& /*os*/)
    {
        throw "benc:null unsupported";
    }
    item::ptr decodeundef(std::istream& /*is*/)
    {
        throw "benc-undef unsupported";
    }

    // Reads one complete item; the first character selects the kind.
    item::ptr decode(std::istream& is)
    {
        switch(is.peek())
        {
            case -1: throw "benc#eof";
            case 'i': return decodeinteger(is);
            case 'd': return decodedict(is);
            case 'l': return decodelist(is);
            default:
                      if (is.peek()>='0' && is.peek()<='9') {
                          return decodestring(is);
                      }
                      throw "invalid benc#9";
        }
    }
};
struct printencoder : encoder {
    int level;
    printencoder() : level(0) { }

    void indent(std::ostream& os)
    {
        for (int i=0 ; i<level ; i++)
            os << "  ";
    }
    virtual void encode(std::ostream& os, const item_string& s)
    {
        if (isbinary(s)) {
            for (unsigned i=0 ; i<s.length() ; i++) {
                if (i) printf(",");
                printf("%02x", s[i]&0xff);
            }
        }
        else {
            os << "\"" << s << "\"";
        }
    }

    virtual void encode(std::ostream& os, const item_dict& d)
    {
        os << "{\n";
        level++;
        for (item_dict::const_iterator v= d.begin() ; v!=d.end() ; v++)
        {
            indent(os);
            (*v).first->encode(os, this);
            os << " => ";
            (*v).second->encode(os, this);
            os << "\n";
        }
        level--;
        indent(os);
        os << "}";
    }

    virtual void encode(std::ostream& os, const item_vector& l)
    {
        os << "[\n";
        level++;
        for (item_vector::const_iterator v= l.begin() ; v!=l.end() ; v++)
        {
            indent(os);
            (*v)->encode(os, this);
            os << "\n";
        }
        level--;
        indent(os);
        os << "]";
    }

    virtual void encode(std::ostream& os, const item_integer& i)
    {
        os << i;
    }

    virtual void encode(std::ostream& os)
    {
        os << "(null)";
    }
    virtual item::ptr decode(std::istream& is)
    {
            return item::ptr(new undef());
    }
};
struct xmlencoder : encoder {
    static bool issimple(item::ptr p)
    {
        return p->type()!=item::T_DICT && p->type()!=item::T_LIST;
    }
    static bool isvalidname(item::ptr p)
    {
        if (p->type()!=item::T_STR)
            return false;
        item_string &str= p->asstring()->get();
        if (str.size()==0)
            return false;
        char c= str[0];
        if (!(c=='_' || (c>='A'&&c<='Z') || (c>='a'&&c<='z')))
            return false;
        size_t inot= str.find_first_not_of("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz_");

        return inot==str.npos;
    }
    virtual void encode(std::ostream& os, const item_string& s)
    {
        if (s.size()>100)
            os << "largestring";
        else
            os<<s;
    }
    virtual void encode(std::ostream& os, const item_dict& d)
    {
        os<<"<dict";
        unsigned nsimple=0;
        for (item_dict::const_iterator i= d.begin() ; i!=d.end() ; i++)
        {
            if (issimple((*i).second) && isvalidname((*i).first)) {
                os << " ";
                (*i).first->encode(os, this);
                os << "=\"";
                (*i).second->encode(os, this);
                os << "\"";
                nsimple++;
            }
        }
        if (nsimple==d.size()) {
            os << "/>";
            return;
        }
        os << ">";
        int nnamed=0;
        for (item_dict::const_iterator i= d.begin() ; i!=d.end() ; i++)
        {
            if (!issimple((*i).second) && isvalidname((*i).first)) {
                os << "<";
                (*i).first->encode(os, this);
                os << ">";
                (*i).second->encode(os, this);
                os << "</";
                (*i).first->encode(os, this);
                os << ">";

                nnamed++;
            }
        }
        int unnamed=0;
        for (item_dict::const_iterator i= d.begin() ; i!=d.end() ; i++)
        {
            if (!isvalidname((*i).first)) {
                os << "<ditem key=\"";
                (*i).first->encode(os, this);
                os << "\">";
                (*i).second->encode(os, this);
                os << "</ditem>";

                unnamed++;
            }
        }
        os<<"</dict>";

        //printf("\n\n  %d  %d  %d = %d\n", nsimple, nnamed, unnamed, d.size());
    }
    virtual void encode(std::ostream& os, const item_vector& l)
    {
        os<<"<array>";
        for (item_vector::const_iterator i= l.begin() ; i!=l.end() ; i++)
        {
            os << "<aitem>";
            (*i)->encode(os, this);
            os << "</aitem>";
        }
        os<<"</array>";
    }
    virtual void encode(std::ostream& os, const item_integer& i)
    {
        os<<i;
    }
    virtual void encode(std::ostream& os)
    {
        os << "&null;";
    }

    struct taginfo {
        item_string name;
        item_string contents;
        std::map<item_string,item_string> attributes;
        bool close;
        bool open;
    };
    char skipspaces(std::istream& is)
    {
        char c;
        while(isspace(c= is.get()))
            ;
        return c;
    }
    static bool isquote(char c) { return c=='\'' || c=='"'; }
    taginfo parsetag(std::istream& is)
    {
        taginfo tag;
        char c;
        c= skipspaces(is);
//        bool allowattrs= true;
        if (c=='/') {
            tag.close= true;
            tag.open= false;
        }
        else {
            tag.open= true;
        }
        while(!isspace(c= is.get()))
            tag.name += c;
        if (!tag.open) {
            c= skipspaces(is);
            if (c!='>')
                throw "invalid xml";
            return tag;
        }
        while(1)
        {
            c= skipspaces(is);
            if (c=='/') {
                tag.close= true;
                c= skipspaces(is);
                if (c!='>')
                    throw "invalid xml";
                return tag;
            }
            else if (c=='>')
                return tag;
            is.unget();
            std::string key;
            while(!isspace(c= is.get()) && c!='=')
                key+= c;
            if (c!='=')
                c= skipspaces(is);
            if (c!='=')
                throw "invalid xml";
            c= skipspaces(is);
            std::string value;
            if (isquote(c)) {
                char quotechar= c;
                while ((c= is.get())!=quotechar)
                    value += c;
            }
            else {
                while (!isspace(c= is.get()) && c!='/' && c!='>')
                    value += c;
            }
            tag.attributes[key]= value;
        }
    }

    virtual item::ptr decode(std::istream& is)
    {
//        bool intag= false;
        std::vector<taginfo> tagstack;
        std::vector<item::ptr> itemstack;
        std::string text;
        while (!is.eof())
        {
            char c= is.get();
            if (c=='<') {
                if (!text.empty()) {
                    tagstack.back().contents= text;
                    text.clear();
                }
                tagstack.push_back(parsetag(is));
                if (tagstack.back().open) {
                    if (tagstack.back().name=="dict") {
                        itemstack.push_back(item::ptr(new dict()));
                        // todo: add tagstack.back().attributes
                    }
                    else if (tagstack.back().name=="array") {
                        itemstack.push_back(item::ptr(new list()));
                    }
                }
                if (tagstack.back().close) {
                    if (tagstack.back().name=="aitem")
                    {
                        tagstack.pop_back();
                        // add to array
                        item::ptr ai= itemstack.back();
                        itemstack.pop_back();
                        itemstack.back()->aslist()->get().push_back(ai);
                    }
                    else if (tagstack.back().name=="ditem")
                    {
                        tagstack.pop_back();
                        // add to dict
                        item::ptr ai= itemstack.back();
                        itemstack.pop_back();
                        //todo: itemstack.back()->asdict()->get()[key]= value;
                    }


                }
            }
            else {
                text += c;
            }
        }
        if (itemstack.size()!=1)
            throw "xmldecode error";
        return itemstack.front();
    }
};

#include "gsmktests.h"
#include "benctests.h"


// Reads `is` until get() reports end of input and prints every byte to
// stdout as a space-prefixed two digit hex number, followed by a newline.
void hexdump(std::istream& is)
{
    for (int byte= is.get() ; byte!=-1 ; byte= is.get())
        printf(" %02x", uint8_t(byte)&0xff);
    printf("\n");
}

// str - ss_gsmk1 -g-> tree1 -b-> ss_benc -b-> tree2 -g-> ss_gsmk2
//
// Round trip test: decodes `str` with `first`, re-encodes with `second`,
// decodes that again with `second` and finally encodes with `first`.
// Returns true only when the final encoding is identical to `str`.
// Returns false when any step throws or when the result differs; in the
// latter case both byte sequences are dumped in hex.
// Input that was not consumed completely by a decoder is reported but
// does not by itself fail the test.
bool test_two(const std::string& str, encoder *first, encoder *second)
{
    try {

    std::stringstream ssin(str);
    item::ptr tree1= first->decode(ssin);
    if ((int)ssin.tellg()!=(int)str.size()) {
        printf("not processed entirely: pos=%x, end=%x\n", (int)ssin.tellg(), (int)str.size());
        printf("nextchar: %02x\n", ssin.get());
    }
    std::stringstream ss2;
    tree1->encode(ss2, second);

    ss2.seekg(0);
    item::ptr tree2= second->decode(ss2);
    if ((int)ss2.tellg()!=(int)ss2.str().size()) {
        printf("not processed entirely: pos=%x, end=%x\n", (int)ss2.tellg(), (int)ss2.str().size());
        printf("nextchar: %02x\n", ss2.get());
    }
    std::stringstream sslast;
    tree2->encode(sslast, first);

    const std::string result= sslast.str();
    if (str != result) {
        // ssin has been consumed by the decoder (and may be in a failed
        // state): rewind so that the dump shows the whole input
        ssin.clear();
        ssin.seekg(0);
        sslast.clear();
        sslast.seekg(0);
        printf("orig: (%d)", (int)str.size()); hexdump(ssin);
        printf("new : (%d)", (int)result.size()); hexdump(sslast);
        printf("error\n");
        return false;
    }

    }
    catch (const char*msg)
    {
        printf("ERROR: %s\n", msg);
        return false;
    }
    catch (const std::exception& e)
    {
        // e.g. allocation failures raised by the standard library
        printf("ERROR: %s\n", e.what());
        return false;
    }
    return true;
}
// Decodes one item from `ssin` with `dec` and writes it to `ssout` with `enc`.
// Returns false, after printing the error, when either step throws.
bool convert(std::istream& ssin, std::ostream& ssout, encoder *dec, encoder *enc)
{
    try {

    item::ptr tree= dec->decode(ssin);
    tree->encode(ssout, enc);
    }
    catch (const char*msg)
    {
        printf("ERROR: %s\n", msg);
        return false;
    }
    catch (const std::exception& e)
    {
        // e.g. allocation failures raised by the standard library
        printf("ERROR: %s\n", e.what());
        return false;
    }
    return true;
}

// Returns the complete contents of the file `name`, read in binary mode.
// The name "-" selects stdin, which is left open afterwards.
// When the file cannot be opened the error is reported with perror()
// and an empty string is returned.
std::string readfile(const char*name)
{
    const bool fromstdin= strcmp(name, "-")==0;
    FILE *fp= stdin;
    if (!fromstdin)
        fp= fopen(name, "rb");
    if (fp==NULL) {
        perror(name);
        return "";
    }

    const size_t chunksize= 0x10000;
    std::vector<char> chunk(chunksize);
    std::string contents;
    while (!feof(fp))
    {
        const size_t got= fread(&chunk[0], 1, chunksize, fp);
        contents.append(&chunk[0], got);
        // a read of zero bytes means end of file or a read error
        if (got==0)
            break;
    }
    if (!fromstdin)
        fclose(fp);
    return contents;
}
int main(int argc, char**argv)
{
    if (argc>1) {
        for (int i=1 ; i<argc ; i++) {
            if (argc>2)
                printf("==> %s <==\n", argv[i]);
            std::string data= readfile(argv[i]);
            test_two(data, new bencoder(), new gsmkencoder());
            std::stringstream ss(data);
            convert(ss, std::cout, new bencoder(), new printencoder());
            printf("\n");
        }
    }
    else {
   //std::cout<<"<?xml version=\"1.0\"?>\n<doc>";
   //std::cout<<"</doc>";
    int b_ok=0;
    int b_err=0;
    int g_ok=0;
    int g_err=0;
        for (unsigned i=0 ; i<NBTESTS ; i++) {
            if (test_two(btests[i], new bencoder(), new gsmkencoder()))
                b_ok++;
            else
                b_err++;
        }
        for (unsigned i=0 ; i<NGTESTS ; i++) {
            if (test_two(gtests[i], new gsmkencoder(), new bencoder()))
                g_ok++;
            else
                g_err++;
        }
        printf("ok: %d/%d  err: %d/%d\n", b_ok, g_ok, b_err, g_err);
    }

}
