// g++-mp-4.5 -O3 torrentcheck.cpp -o torrentcheck -lcrypto
#include <stdio.h>
#include <stdint.h>
#include <string>
#include <string.h>
#include <algorithm>
#include <map>
#include <vector>

#include <sstream>
#include <iostream>
#include <iterator>
#include <openssl/sha.h>

#include <boost/pointer_cast.hpp>

#include <boost/shared_ptr.hpp>

// TODO: add code to repeatedly fetch the peer list from the tracker, and record which peers are active and when.

// <announceurl>?info_hash=<%encodedhash>
//   peer_id= ...
//   port= ...
//   uploaded= 0
//   downloaded = 0
//   left= 123123123123
//   numwant =80
//   key= ?
//   compact= {0|1}
//   supportcrypto= {0|1}
//   event=started

// Verbosity level: incremented once per -v on the command line. When
// non-zero, dump_torrent also prints the whole decoded tree.
int verbose= 0;
// List of strings (command-line file names, announce URLs).
typedef std::vector<std::string> StringList;

// add template subclass of istream for sstring
// Forward declarations needed by item's checked downcasts (item::as*()).
// The value types are declared with `struct`, matching their definitions;
// mixing class/struct keys for one type triggers -Wmismatched-tags (clang)
// and C4099 (MSVC), and MSVC even mangles the two differently.
struct encoder;

struct dict;
struct list;
struct integer;
struct bstring;
// Base class of every bencode value. The concrete types are dict, list,
// integer, bstring and undef; type() reports which one an item is. Items
// are shared through item::ptr (boost::shared_ptr<item>).
struct item {

    // Encoded bytes this item was constructed with. The bencoder's decode
    // functions store the exact input span here; it is empty for items
    // constructed without it. dump_torrent hashes the "info" dict's data()
    // to compute the info_hash.
    std::string _data;

    typedef boost::shared_ptr<item> ptr;
    // Type tags returned by type(). T_TINY and T_FLOAT are not returned by
    // any of the item classes here.
    enum { T_UNDEF, T_DICT, T_LIST, T_INT, T_STR, T_TINY, T_FLOAT };

    // Polymorphic base: the destructor must be virtual so that destroying a
    // derived item through an item* is well-defined.
    virtual ~item() { }

    // Double dispatch: hand this item's contents to the matching encoder
    // overload.
    virtual void encode(std::ostream& os, encoder* e) const=0;
    virtual int type() const=0;

    item() { }
    item(const std::string& data) : _data(data) { }

    // Checked downcasts. Each throws a const char* message when type() does
    // not match. They go through the boost::dynamic_pointer_cast template
    // rather than a direct dynamic_cast because dict/list/integer/bstring are
    // only forward-declared at this point, and a dynamic_cast to an
    // incomplete type does not compile.
    dict *asdict() { if (type()!=T_DICT) throw "not a dict"; return boost::dynamic_pointer_cast<dict>(this); }
    list *aslist() { if (type()!=T_LIST) throw "not a list"; return boost::dynamic_pointer_cast<list>(this); }
    integer *asinteger() { if (type()!=T_INT) throw "not a integer"; return boost::dynamic_pointer_cast<integer>(this); }
    bstring *asstring() { if (type()!=T_STR) throw "not a string"; return boost::dynamic_pointer_cast<bstring>(this); }

    const std::string& data() const { return _data; }
/*
    int threewaycompare(const item& b) const
    {
        if (type()<b.type())
            return -1;
        if (type()>b.type())
            return 1;
        return 0;
    }
*/
};

// Storage types held by the concrete item classes.
typedef std::vector<item::ptr> item_vector;          // list contents
typedef std::map<item::ptr,item::ptr> item_dict;     // dict contents; keys ordered by operator< on item::ptr
typedef std::string item_string;                     // string contents (arbitrary bytes, not necessarily text)
typedef int64_t item_integer;                        // integer value

// Heuristic used by printencoder: should `s` be shown as hex bytes rather
// than as quoted text?
//
// Returns true when `s` contains an ASCII control character other than
// tab, newline or carriage return (or DEL), or when bytes with the high bit
// set and bytes without it occur in roughly equal numbers
// (0.5 < low/high < 1.5). Random bytes such as hashes have about half their
// high bits set. Returns false for the empty string.
bool isbinary(const item_string& s)
{
    size_t x[2]= {0,0};   // x[0]: bytes < 0x80, x[1]: bytes >= 0x80
    for (item_string::const_iterator p= s.begin() ; p!=s.end() ; ++p)
    {
        const unsigned char c= static_cast<unsigned char>(*p);
        // Control characters do not occur in printable names/URLs. Writing
        // them raw (NUL, ESC sequences, ...) garbles the terminal, and short
        // hashes often miss the ratio test below.
        if ((c<0x20 && c!='\t' && c!='\n' && c!='\r') || c==0x7f)
            return true;
        x[c>>7]++;
    }
    // 0.5 < x[0]/x[1] < 1.5
    return (x[1]<2*x[0]) && (2*x[0]<3*x[1]);
}
// Abstract codec interface. item::encode() double-dispatches into the
// overload matching the item's contents; the no-argument overload is used
// for undef items. decode() produces an item tree from a stream.
struct encoder {
    // Concrete encoders are allocated and handled through encoder*, so the
    // destructor must be virtual for deletion through the base to be valid.
    virtual ~encoder() { }

    virtual void encode(std::ostream& os, const item_string& s)=0;
    virtual void encode(std::ostream& os, const item_dict& d)=0;
    virtual void encode(std::ostream& os, const item_vector& l)=0;
    virtual void encode(std::ostream& os, const item_integer& i)=0;
    virtual void encode(std::ostream& os)=0;

    virtual item::ptr decode(std::istream& is)=0;
};

/*
template<typename T> 
class threewaycomparable {
    virtual T& get()=0;
    int threewaycompare(const T& b) const
    {
        int c= item::threewaycompare(b);
        if (c)
            return c;
        if (get()<b.get())
            return -1;
        if (get()>b.get())
            return 1;
        return 0;
    }
};
*/

// Bencode byte string. _s holds the string value; the base-class _data
// holds the encoded form passed to the constructor (empty for the lookup
// keys built by dict::operator[]).
struct bstring : item {
    item_string _s;

    explicit bstring(const item_string& s, const std::string& data)
        : item(data), _s(s)
    {
    }

    virtual int type() const { return T_STR; }

    // Mutable access to the string value.
    item_string& get() { return _s; }

    virtual void encode(std::ostream& os, encoder* e) const
    {
        // Double dispatch: the encoder picks its string overload.
        e->encode(os, _s);
    }
};

// Bencode dictionary: a map from key items to value items, ordered by
// operator< on item::ptr.
struct dict : item {
    item_dict _d;

    dict() { }
    explicit dict(const item_dict& d, const std::string& data)
        : item(data), _d(d)
    {
    }

    virtual int type() const { return T_DICT; }

    item_dict& get() { return _d; }

    // Value stored under the string key `key`, with std::map::operator[]
    // semantics: when the key is absent, an entry holding an empty (null)
    // item::ptr is inserted and a reference to it returned. Read-only
    // callers must therefore null-check the result, and they grow the dict.
    item::ptr& operator[](const std::string& key)
    {
        const item::ptr probe(new bstring(key, ""));
        return _d[probe];
    }

    virtual void encode(std::ostream& os, encoder* e) const
    {
        e->encode(os, _d);
    }
};

// Bencode list: a vector of items.
struct list : item {
    item_vector _l;

    list() { }
    explicit list(const item_vector& l, const std::string& data)
        : item(data), _l(l)
    {
    }

    virtual int type() const { return T_LIST; }

    item_vector& get() { return _l; }
    size_t size() const { return _l.size(); }

    // Growing element access: an index at or past the end first extends
    // the list with empty (null) item::ptr entries up to and including `i`.
    item::ptr& operator[](unsigned i)
    {
        if (!(i < _l.size()))
            _l.resize(i+1);
        return _l[i];
    }

    virtual void encode(std::ostream& os, encoder* e) const
    {
        e->encode(os, _l);
    }
};

// Bencode integer, stored as an item_integer (64-bit signed).
struct integer : item {
    item_integer _i;

    explicit integer(item_integer i, const std::string& data)
        : item(data), _i(i)
    {
    }

    virtual int type() const { return T_INT; }

    // Mutable access to the integer value.
    item_integer& get() { return _i; }

    virtual void encode(std::ostream& os, encoder* e) const
    {
        e->encode(os, _i);
    }
};
// Placeholder item carrying no value (printencoder::decode() returns one).
// Encoding it calls the encoder's no-argument encode() overload.
struct undef : item {

    explicit undef() { }

    virtual int type() const { return T_UNDEF; }

    // There is no value to hand out.
    void get() { }

    virtual void encode(std::ostream& os, encoder* e) const
    {
        e->encode(os);
    }
};

bool operator<(item::ptr a,item::ptr b)
{
    if (a->type() < b->type())
        return true;
    if (a->type() > b->type())
        return false;
    switch(a->type()) {
        case item::T_INT: return a->asinteger()->get() < b->asinteger()->get();
        case item::T_STR: return a->asstring()->get() < b->asstring()->get();
    }
    return false;
}

// Return the node of `tree` addressed by `path`, written as
// <sub>.<sub>.<sub>. Each <sub> is a list index (parsed with strtoul,
// base 0) when the current node is a list, or a dict key when it is a dict.
//
// Returns an empty item::ptr if the last component names a missing dict key.
// Throws a const char* if:
//  - a component is not entirely a number where an index is expected;
//  - an index is out of range;
//  - the path descends into something that is neither a dict nor a list
//    (including the null left by a missing intermediate key).
// The tree is not modified.
item::ptr getnode(item::ptr tree, const std::string& path)
{
    size_t off= 0;

    while (1)
    {
        if (!tree)
            throw "path: no dict or list";

        size_t dot= path.find('.', off);
        // substr() takes a length, not an end position.
        std::string component= dot==std::string::npos
            ? path.substr(off)
            : path.substr(off, dot-off);

        if (tree->type()==item::T_LIST) {
            const char *start= component.c_str();
            char *end= 0;
            unsigned long idx= strtoul(start, &end, 0);
            // The whole component must be the number, not just a prefix.
            if (end==start || *end!='\0')
                throw "path: expected integer";
            item_vector& elems= tree->aslist()->get();
            if (idx>=elems.size())
                throw "path: index out of range";
            tree= elems[idx];
        }
        else if (tree->type()==item::T_DICT) {
            // Use find() rather than dict::operator[], which would insert a
            // null entry for a missing key.
            item_dict& entries= tree->asdict()->get();
            item_dict::iterator it= entries.find(item::ptr(new bstring(component, "")));
            tree= it==entries.end() ? item::ptr() : it->second;
        }
        else {
            throw "path: no dict or list";
        }

        if (dot==std::string::npos)
            break;
        off= dot+1;
    }
    return tree;
}

// Bencode codec. decode() parses one value from a stream into an item tree,
// recording in each item the exact bytes it was parsed from (item::data()).
// The encode() overloads write values back out in bencode form. Malformed
// input is reported by throwing a const char* message.
//
// NOTE(review): decode() recurses once per dict/list nesting level with no
// depth limit, so pathologically nested input can exhaust the stack.
struct bencoder : encoder {

    // Write `value` in decimal, with a leading '-' if negative.
    static void encode_integer(std::ostream& os, int64_t value)
    {
        if (value==0) {
            os << '0';
            return;
        }
        // Negate in unsigned arithmetic: -value overflows for INT64_MIN,
        // whereas unsigned wrap-around is well defined.
        uint64_t magnitude= uint64_t(value);
        if (value<0) {
            os << '-';
            magnitude= uint64_t(0)-magnitude;
        }
        item_string num;
        while (magnitude)
        {
            num += char('0'+(magnitude % 10));
            magnitude /= 10;
        }
        std::copy(num.rbegin(), num.rend(), std::ostream_iterator<char>(os));
    }

    // Parse an optional '-' followed by decimal digits, appending every
    // consumed character to `data`. The first non-digit is put back into the
    // stream. Throws if there are no digits, or if the value does not fit in
    // int64_t.
    static int64_t decode_integer(std::istream& is, std::string& data)
    {
        const uint64_t int64max= (uint64_t(1)<<63)-1;
        bool negative=false;
        int c= is.get();
        if (c=='-') {
            data += '-';
            negative=true;
            c= is.get();
        }
        // A negative value's magnitude may be one larger (INT64_MIN).
        const uint64_t limit= negative ? int64max+1 : int64max;
        uint64_t magnitude=0;
        bool anydigits=false;
        while (c>='0' && c<='9')
        {
            const unsigned digit= unsigned(c-'0');
            if (magnitude > (limit-digit)/10)
                throw "invalid benc: integer overflow";
            magnitude= magnitude*10+digit;
            anydigits=true;
            data += char(c);
            c= is.get();
        }
        is.unget();
        if (!anydigits)
            throw "invalid benc: expected digits";

        if (!negative)
            return int64_t(magnitude);
        if (magnitude==0)
            return 0;
        // -(m-1)-1 never has to represent 2^63 as an int64_t.
        return -int64_t(magnitude-1)-1;
    }

    virtual void encode(std::ostream& os, const item_string& s)
    {
        encode_integer(os, int64_t(s.size()));
        os << ':';
        std::copy(s.begin(), s.end(), std::ostream_iterator<char>(os));
    }

    // <length>:<bytes>
    item::ptr decodestring(std::istream& is)
    {
        std::string data;
        int64_t length= decode_integer(is, data);
        if (is.get() != ':')
            throw "invalid benc#1";
        if (length<0)
            throw "invalid benc#2";
        data += ':';

        // Read in bounded chunks instead of resizing to `length` up front.
        // The length comes from untrusted input, and a bogus huge value would
        // otherwise attempt a giant allocation (std::bad_alloc /
        // std::length_error) before the missing data is noticed.
        item_string s;
        char buf[4096];
        int64_t remaining= length;
        while (remaining>0)
        {
            const std::streamsize want= remaining < int64_t(sizeof(buf))
                ? std::streamsize(remaining)
                : std::streamsize(sizeof(buf));
            is.read(buf, want);
            const std::streamsize got= is.gcount();
            s.append(buf, size_t(got));
            if (got!=want)
                throw "invalid benc#2";
            remaining -= got;
        }

        data += s;
        return item::ptr(new bstring(s, data));
    }
    virtual void encode(std::ostream& os, const item_dict& d)
    {
        os << 'd';
        for (item_dict::const_iterator v= d.begin() ; v!=d.end() ; v++)
        {
            (*v).first->encode(os, this);
            (*v).second->encode(os, this);
        }
        os << 'e';
    }

    // d<key><value>...e
    item::ptr decodedict(std::istream& is)
    {
        std::string data;
        item_dict d;
        if (is.get() != 'd')
            throw "invalid benc#3";
        data += 'd';
        while (is.peek()!='e') {
            item::ptr key= decode(is);
            data += key->data();
            item::ptr value= decode(is);
            data += value->data();
            d.insert(item_dict::value_type(key, value));
        }
        if (is.get() != 'e')
            throw "invalid benc#4";
        data += 'e';
        return item::ptr(new dict(d, data));
    }
    virtual void encode(std::ostream& os, const item_vector& l)
    {
        os << 'l';
        for (item_vector::const_iterator v= l.begin() ; v!=l.end() ; v++)
        {
            (*v)->encode(os, this);
        }
        os << 'e';
    }

    // l<value>...e
    item::ptr decodelist(std::istream& is)
    {
        std::string data;
        item_vector l;
        if (is.get() != 'l')
            throw "invalid benc#5";
        data += 'l';
        while (is.peek()!='e') {
            item::ptr value= decode(is);
            data += value->data();
            l.push_back(value);
        }
        if (is.get() != 'e')
            throw "invalid benc#6";
        data += 'e';
        return item::ptr(new list(l, data));
    }
    virtual void encode(std::ostream& os, const item_integer& i)
    {
        os << 'i';
        encode_integer(os, i);
        os << 'e';
    }


    // i<digits>e
    item::ptr decodeinteger(std::istream& is)
    {
        std::string data;
        if (is.get() != 'i')
            throw "invalid benc#7";
        data += 'i';
        item_integer i= decode_integer(is, data);
        if (is.get() != 'e')
            throw "invalid benc#8";
        data += 'e';
        return item::ptr(new integer(i, data));
    }
    virtual void encode(std::ostream& os)
    {
        throw "benc:null unsupported";
    }

    item::ptr decodeundef(std::istream& is)
    {
        throw "benc-undef unsupported";
    }

    // Dispatch on the next character without consuming it.
    item::ptr decode(std::istream& is)
    {
        switch(is.peek())
        {
            case -1: throw "benc#eof";
            case 'i': return decodeinteger(is);
            case 'd': return decodedict(is);
            case 'l': return decodelist(is);
            default:
                      if (is.peek()>='0' && is.peek()<='9') {
                          return decodestring(is);
                      }
                      throw "invalid benc#9";
        }
    }
};
struct printencoder : encoder {
    int level;
    printencoder() : level(0) { }

    void indent(std::ostream& os)
    {
        for (int i=0 ; i<level ; i++)
            os << "  ";
    }
    virtual void encode(std::ostream& os, const item_string& s)
    {
        if (isbinary(s)) {
            for (unsigned i=0 ; i<s.length() ; i++) {
                if (i) printf(",");
                printf("%02x", s[i]&0xff);
            }
        }
        else {
            os << "\"" << s << "\"";
        }
    }

    virtual void encode(std::ostream& os, const item_dict& d)
    {
        os << "{\n";
        level++;
        for (item_dict::const_iterator v= d.begin() ; v!=d.end() ; v++)
        {
            indent(os);
            (*v).first->encode(os, this);
            os << " => ";
            (*v).second->encode(os, this);
            os << "\n";
        }
        level--;
        indent(os);
        os << "}";
    }

    virtual void encode(std::ostream& os, const item_vector& l)
    {
        os << "[\n";
        level++;
        for (item_vector::const_iterator v= l.begin() ; v!=l.end() ; v++)
        {
            indent(os);
            (*v)->encode(os, this);
            os << "\n";
        }
        level--;
        indent(os);
        os << "]";
    }

    virtual void encode(std::ostream& os, const item_integer& i)
    {
        os << i;
    }

    virtual void encode(std::ostream& os)
    {
        os << "(null)";
    }
    virtual item::ptr decode(std::istream& is)
    {
            return item::ptr(new undef());
    }
};


// Print every remaining byte of `is` to stdout as " xx" (two lowercase hex
// digits), then a newline. The stream is consumed until get() returns -1.
void hexdump(std::istream& is)
{
    for (int ch= is.get() ; ch != -1 ; ch= is.get())
        printf(" %02x", uint8_t(ch)&0xff);
    printf("\n");
}

// Map a nybble value (intended range 0..15) to its lowercase hex digit.
char nyblechar(int n)
{
    return n < 10 ? char('0' + n) : char('a' + (n - 10));
}
// Write the `n` bytes starting at `p` to `ssout`, each as a space followed
// by two lowercase hex digits (e.g. " 0a ff").
void hexdump(const uint8_t *p, size_t n, std::ostream& ssout)
{
    for (const uint8_t *end= p+n ; p!=end ; ++p)
        ssout << ' ' << nyblechar(*p>>4) << nyblechar(*p&15);
}
bool dump_torrent(std::istream& ssin, std::ostream& ssout)
{
    try {

    encoder *dec= new bencoder();
    item::ptr tree= dec->decode(ssin);

    dict& d= *tree->asdict();

    item::ptr info= d["info"];

    SHA_CTX ctx;
    SHA1_Init(&ctx);
    SHA1_Update(&ctx, info->data().c_str(), info->data().size());
    uint8_t hash[SHA_DIGEST_LENGTH];
    SHA1_Final(hash, &ctx);
    //ssout << "infodata: "; hexdump((const uint8_t*)info->data().c_str(), info->data().size(), ssout); ssout << "\n";
    ssout << "info_hash: "; hexdump(hash, SHA_DIGEST_LENGTH, ssout); ssout << "\n";

    item::ptr name= (*info->asdict())["name"];
    ssout << "name: " << name->asstring()->get() << "\n";

    item::ptr announce= d["announce"];
    ssout << "announce: " << announce->asstring()->get() << "\n";

    item::ptr announcelist= d["announce-list"];

    if (announcelist)
    {
        list& l= *announcelist->aslist();
        StringList urls;
        for (unsigned i= 0 ; i<l.size() ; i++)
        {
            item::ptr lp= l[i];
            list& ll= *lp->aslist();
            item::ptr lpp= ll[0];
            urls.push_back(lpp->asstring()->get());
            
            printf(" %s\n", urls.back().c_str());
        }
    }

    if (verbose) {
        encoder *enc= new printencoder();
        tree->encode(ssout, enc);
        printf("\n");
    }
    }
    catch (const char*msg)
    {
        printf("ERROR: %s\n", msg);
        return false;
    }
    return true;
}

// Read the whole file `name` into a string; "-" reads standard input.
// If the file cannot be opened or a read fails, the error is reported with
// perror() (prefixed by the file name) and an empty string is returned.
std::string readfile(const std::string& name)
{
    const size_t chunk= 0x10000;
    bool isstdin= name=="-";
    FILE *f= isstdin ? stdin : fopen(name.c_str(), "rb");
    if (f==NULL) {
        perror(name.c_str());
        return "";
    }
    std::string data;
    size_t n;
    do {
        // Grow by one chunk, read into the new tail, then trim to what arrived.
        size_t used= data.size();
        data.resize(used+chunk);
        n= fread(&data[used], 1, chunk, f);
        data.resize(used+n);
    } while (n==chunk);   // a short read means EOF or an error

    // A short read may also be an error (e.g. reading a directory on Linux).
    // Don't let it pass for EOF and hand back truncated data.
    bool failed= ferror(f)!=0;
    if (failed)
        perror(name.c_str());
    if (!isstdin)
        fclose(f);
    return failed ? std::string() : data;
}
// Usage: torrentcheck [-v ...] FILE...
// Dumps each FILE as a torrent; a lone "-" reads standard input. Each -v
// increments `verbose`. When several files are given, each is preceded by a
// "==> name <==" header. The exit status is 0 if every file was dumped
// successfully and 1 otherwise.
int main(int argc, char**argv)
{
    StringList files;
    for (int i=1 ; i<argc ; i++)
    {
        const std::string arg= argv[i];
        // A lone "-" is a file name (stdin, see readfile), not an option.
        if (arg.size()>1 && arg[0]=='-') {
            switch(arg[1]) {
                case 'v': verbose++; break;
                default:
                    fprintf(stderr, "unknown option: %s\n", argv[i]);
                    break;
            }
        }
        else
            files.push_back(arg);
    }
    int status= 0;
    for (size_t i=0; i<files.size() ; i++) {
        if (files.size()>1)
            printf("==> %s <==\n", files[i].c_str());
        std::string data= readfile(files[i]);
        std::stringstream ss(data);
        if (!dump_torrent(ss, std::cout))
            status= 1;
    }
    return status;
}

