diff --git a/src/dht/dht_bucket.cc b/src/dht/dht_bucket.cc index 04aa475..819f4aa 100644 --- a/src/dht/dht_bucket.cc +++ b/src/dht/dht_bucket.cc @@ -52,6 +52,8 @@ DhtBucket::DhtBucket(const HashString& begin, const HashString& end) : m_good(0), m_bad(0), + m_fullCacheLength(0), + m_begin(begin), m_end(end) { @@ -67,6 +69,8 @@ DhtBucket::add_node(DhtNode* n) { m_good++; else if (n->is_bad()) m_bad++; + + m_fullCacheLength = 0; } void @@ -81,6 +85,8 @@ DhtBucket::remove_node(DhtNode* n) { m_good--; else if (n->is_bad()) m_bad--; + + m_fullCacheLength = 0; } void @@ -92,9 +98,11 @@ DhtBucket::count() { // Called every 15 minutes for housekeeping. void DhtBucket::update() { - // For now we only update the counts after some nodes have become bad - // due to prolonged inactivity. count(); + + // In case adjacent buckets whose nodes we borrowed have changed, + // we force an update of the cache. + m_fullCacheLength = 0; } DhtBucket::iterator @@ -188,4 +196,23 @@ DhtBucket::split(const HashString& id) { return other; } +void +DhtBucket::build_full_cache() { + DhtBucketChain chain(this); + + char* pos = m_fullCache; + do { + for (const_iterator itr = chain.bucket()->begin(); itr != chain.bucket()->end() && pos < m_fullCache + sizeof(m_fullCache); ++itr) { + if (!(*itr)->is_bad()) { + pos = (*itr)->store_compact(pos); + + if (pos > m_fullCache + sizeof(m_fullCache)) + throw internal_error("DhtRouter::store_closest_nodes wrote past buffer end."); + } + } + } while (pos < m_fullCache + sizeof(m_fullCache) && chain.next() != NULL); + + m_fullCacheLength = pos - m_fullCache; +} + } diff --git a/src/dht/dht_bucket.h b/src/dht/dht_bucket.h index 97622a3..bcdfd67 100644 --- a/src/dht/dht_bucket.h +++ b/src/dht/dht_bucket.h @@ -111,6 +111,10 @@ public: DhtBucket* parent() const { return m_parent; } DhtBucket* child() const { return m_child; } + // Return a full bucket's worth of compact node data. If this bucket is not + // full, it uses nodes from the child/parent buckets until we have enough. + SimpleString full_bucket(); + // Called by the DhtNode on its bucket to update good/bad node counts. void node_now_good(bool was_bad); void node_now_bad(bool was_good); @@ -118,6 +122,8 @@ public: private: void count(); + void build_full_cache(); + DhtBucket* m_parent; DhtBucket* m_child; @@ -126,11 +132,15 @@ private: unsigned int m_good; unsigned int m_bad; + size_t m_fullCacheLength; + // These are 40 bytes together, so might as well put them last. // m_end is const because it is used as key for the DhtRouter routing table // map, which would be inconsistent if m_end were changed carelessly. HashString m_begin; const HashString m_end; + + char m_fullCache[num_nodes * 26]; }; // Helper class to recursively follow a chain of buckets. It first recurses @@ -160,6 +170,13 @@ DhtBucket::node_now_bad(bool was_good) { m_bad++; } +inline SimpleString +DhtBucket::full_bucket() { + if (!m_fullCacheLength) + build_full_cache(); + return SimpleString(m_fullCache, m_fullCacheLength); +} + inline const DhtBucket* DhtBucketChain::next() { // m_restart is clear when we're done recursing into the children and diff --git a/src/dht/dht_node.cc b/src/dht/dht_node.cc index 3574807..9d51a28 100644 --- a/src/dht/dht_node.cc +++ b/src/dht/dht_node.cc @@ -59,8 +59,8 @@ DhtNode::DhtNode(const HashString& id, const rak::socket_address* sa) : throw resource_error("Address not af_inet"); } -DhtNode::DhtNode(const std::string& id, const Object& cache) : - HashString(*HashString::cast_from(id.c_str())), +DhtNode::DhtNode(const SimpleString& id, const Object& cache) : + HashString(*HashString::cast_from(id)), m_recentlyActive(false), m_recentlyInactive(0), m_bucket(NULL) { diff --git a/src/dht/dht_node.h b/src/dht/dht_node.h index 032c5cc..8234add 100644 --- a/src/dht/dht_node.h +++ b/src/dht/dht_node.h @@ -57,7 +57,7 @@ public: static const unsigned int max_failed_replies = 5; DhtNode(const HashString& id, const rak::socket_address* sa); - DhtNode(const std::string& id, const Object& cache); + DhtNode(const SimpleString& id, const Object& cache); const HashString& id() const { return *this; } const rak::socket_address* address() const { return &m_socketAddress; } diff --git a/src/dht/dht_router.cc b/src/dht/dht_router.cc index b1c95c3..e9abe5d 100644 --- a/src/dht/dht_router.cc +++ b/src/dht/dht_router.cc @@ -329,24 +329,6 @@ DhtRouter::node_invalid(const HashString& id) { delete_node(m_nodes.find(&node->id())); } -char* -DhtRouter::store_closest_nodes(const HashString& id, char* buffer, char* bufferEnd) { - DhtBucketChain chain(find_bucket(id)->second); - - do { - for (DhtBucket::const_iterator itr = chain.bucket()->begin(); itr != chain.bucket()->end() && buffer != bufferEnd; ++itr) { - if (!(*itr)->is_bad()) { - buffer = (*itr)->store_compact(buffer); - - if (buffer > bufferEnd) - throw internal_error("DhtRouter::store_closest_nodes wrote past buffer end."); - } - } - } while (buffer != bufferEnd && chain.next() != NULL); - - return buffer; -} - Object* DhtRouter::store_cache(Object* container) const { container->insert_key("self_id", str()); @@ -355,7 +337,7 @@ DhtRouter::store_cache(Object* container) const { Object& nodes = container->insert_key("nodes", Object::create_map()); for (DhtNodeList::const_accessor itr = m_nodes.begin(); itr != m_nodes.end(); ++itr) { if (!itr.node()->is_bad()) - itr.node()->store_cache(&nodes.insert_key(itr.id().str(), Object::create_map())); + itr.node()->store_cache(&nodes.insert_key(itr.id().s_str(), Object::create_map())); } // Insert contacts, if we have any. @@ -470,7 +452,7 @@ DhtRouter::receive_timeout() { for (DhtBucketList::const_iterator itr = m_routingTable.begin(); itr != m_routingTable.end(); ++itr) { itr->second->update(); - if (!itr->second->is_full() || itr->second->age() > timeout_bucket_bootstrap) + if (!itr->second->is_full() || itr->second == bucket() || itr->second->age() > timeout_bucket_bootstrap) bootstrap_bucket(itr->second); } @@ -505,15 +487,13 @@ DhtRouter::generate_token(const rak::socket_address* sa, int token, char buffer[ return buffer; } -std::string -DhtRouter::make_token(const rak::socket_address* sa) { - char token[20]; - - return std::string(generate_token(sa, m_curToken, token), size_token); +SimpleString +DhtRouter::make_token(const rak::socket_address* sa, char* buffer) { + return SimpleString(generate_token(sa, m_curToken, buffer), size_token); } bool -DhtRouter::token_valid(const std::string& token, const rak::socket_address* sa) { +DhtRouter::token_valid(SimpleString token, const rak::socket_address* sa) { if (token.length() != size_token) return false; @@ -521,12 +501,12 @@ DhtRouter::token_valid(const std::string& token, const rak::socket_address* sa) char reference[20]; // First try current token. - if (std::memcmp(generate_token(sa, m_curToken, reference), token.c_str(), size_token) == 0) + if (token == SimpleString(generate_token(sa, m_curToken, reference), size_token)) return true; // If token recently changed, some clients may be using the older one. // That way a token is valid for 15-30 minutes, instead of 0-15. - return std::memcmp(generate_token(sa, m_prevToken, reference), token.c_str(), size_token) == 0; + return token == SimpleString(generate_token(sa, m_prevToken, reference), size_token); } DhtNode* diff --git a/src/dht/dht_router.h b/src/dht/dht_router.h index f2b673f..816747f 100644 --- a/src/dht/dht_router.h +++ b/src/dht/dht_router.h @@ -115,14 +115,14 @@ public: // Store compact node information (26 bytes) for nodes closest to the // given ID in the given buffer, return new buffer end. - char* store_closest_nodes(const HashString& id, char* buffer, char* bufferEnd); + SimpleString get_closest_nodes(const HashString& id) { return find_bucket(id)->second->full_bucket(); } // Store DHT cache in the given container. Object* store_cache(Object* container) const; // Create and verify a token. Tokens are valid between 15-30 minutes from creation. - std::string make_token(const rak::socket_address* sa); - bool token_valid(const std::string& token, const rak::socket_address* sa); + SimpleString make_token(const rak::socket_address* sa, char* buffer); + bool token_valid(SimpleString token, const rak::socket_address* sa); DhtManager::statistics_type get_statistics() const; void reset_statistics() { m_server.reset_statistics(); } @@ -147,6 +147,8 @@ private: bool add_node_to_bucket(DhtNode* node); void delete_node(const DhtNodeList::accessor& itr); + void store_closest_nodes(const HashString& id, DhtBucket* bucket); + DhtBucketList::iterator split_bucket(const DhtBucketList::iterator& itr, DhtNode* node); void bootstrap(); diff --git a/src/dht/dht_server.cc b/src/dht/dht_server.cc index 1f2234b..256b92b 100644 --- a/src/dht/dht_server.cc +++ b/src/dht/dht_server.cc @@ -38,13 +38,14 @@ #include "globals.h" #include -#include +#include #include "torrent/exceptions.h" #include "torrent/connection_manager.h" #include "torrent/object.h" #include "torrent/object_stream.h" #include "torrent/poll.h" +#include "torrent/static_map.h" #include "torrent/throttle.h" #include "tracker/tracker_dht.h" @@ -63,6 +64,34 @@ const char* DhtServer::queries[] = { "announce_peer", }; +// List of all possible keys we need/support in a DHT message. +// Unsupported keys we receive are dropped (ignored) while decoding. +// See torrent/static_map.h for how this works. +DhtMessage::mapping_type dht_key_names[DhtMessage::length] = { + { key_a_id, "a::id" }, + { key_a_infoHash, "a::info_hash" }, + { key_a_port, "a::port", }, + { key_a_target, "a::target" }, + { key_a_token, "a::token" }, + + { key_e_0, "e[0]" }, + { key_e_1, "e[1]" }, + + { key_q, "q" }, + + { key_r_id, "r::id" }, + { key_r_nodes, "r::nodes" }, + { key_r_token, "r::token" }, + { key_r_values, "r::values[]" }, + + { key_t, "t::" }, + { key_v, "v" }, + { key_y, "y" }, +}; + +template<> +const DhtMessage::key_map_init DhtMessage::base_type::keyMap(dht_key_names); + // Error in DHT protocol, avoids std::string ctor from communication_error class dht_error : public network_error { public: @@ -238,54 +267,51 @@ DhtServer::update() { } void -DhtServer::process_query(const Object& transactionId, const HashString& id, const rak::socket_address* sa, Object& request) { +DhtServer::process_query(const HashString& id, const rak::socket_address* sa, const DhtMessage& msg) { m_queriesReceived++; m_networkUp = true; - std::string& query = request.get_key_string("q"); - - Object& arg = request.get_key("a"); + SimpleString query = msg[key_q].as_sstring(); // Construct reply. - Object reply = Object::create_map(); + DhtMessage reply; if (query == "find_node") - create_find_node_response(arg, reply); + create_find_node_response(msg, reply); else if (query == "get_peers") - create_get_peers_response(arg, sa, reply); + create_get_peers_response(msg, sa, reply); else if (query == "announce_peer") - create_announce_peer_response(arg, sa, reply); + create_announce_peer_response(msg, sa, reply); else if (query != "ping") throw dht_error(dht_error_bad_method, "Unknown query type."); m_router->node_queried(id, sa); - create_response(transactionId, sa, reply); + create_response(msg, sa, reply); } void -DhtServer::create_find_node_response(const Object& arg, Object& reply) { - const std::string& target = arg.get_key_string("target"); +DhtServer::create_find_node_response(const DhtMessage& req, DhtMessage& reply) { + SimpleString target = req[key_a_target].as_sstring(); if (target.length() < HashString::size_data) throw dht_error(dht_error_protocol, "target string too short"); - char compact[sizeof(compact_node_info) * DhtBucket::num_nodes]; - char* end = m_router->store_closest_nodes(*HashString::cast_from(target), compact, compact + sizeof(compact)); - - if (end == compact) + SimpleString nodes = m_router->get_closest_nodes(*HashString::cast_from(target)); + if (nodes.empty()) throw dht_error(dht_error_generic, "No nodes"); - reply.insert_key("nodes", std::string(compact, end)); + reply[key_r_nodes] = nodes; } void -DhtServer::create_get_peers_response(const Object& arg, const rak::socket_address* sa, Object& reply) { - reply.insert_key("token", m_router->make_token(sa)); +DhtServer::create_get_peers_response(const DhtMessage& req, const rak::socket_address* sa, DhtMessage& reply) { + reply[key_r_token] = m_router->make_token(sa, reply.data_end); + reply.data_end += reply[key_r_token].as_sstring().length(); - const std::string& info_hash_str = arg.get_key_string("info_hash"); + SimpleString info_hash_str = req[key_a_infoHash].as_sstring(); if (info_hash_str.length() < HashString::size_data) throw dht_error(dht_error_protocol, "info hash too short"); @@ -296,35 +322,34 @@ DhtServer::create_get_peers_response(const Object& arg, const rak::socket_addres // If we're not tracking or have no peers, send closest nodes. if (!tracker || tracker->empty()) { - char compact[sizeof(compact_node_info) * DhtBucket::num_nodes]; - char* end = m_router->store_closest_nodes(*info_hash, compact, compact + sizeof(compact)); - - if (end == compact) + SimpleString nodes = m_router->get_closest_nodes(*info_hash); + if (nodes.empty()) throw dht_error(dht_error_generic, "No peers nor nodes"); - reply.insert_key("nodes", std::string(compact, end)); + reply[key_r_nodes] = nodes; } else { - reply.insert_key("values", Object::create_list()).as_list().swap(tracker->get_peers().as_list()); + reply[key_r_values] = tracker->get_peers(); } } void -DhtServer::create_announce_peer_response(const Object& arg, const rak::socket_address* sa, Object& reply) { - const std::string& info_hash = arg.get_key_string("info_hash"); +DhtServer::create_announce_peer_response(const DhtMessage& req, const rak::socket_address* sa, DhtMessage& reply) { + SimpleString info_hash = req[key_a_infoHash].as_sstring(); if (info_hash.length() < HashString::size_data) throw dht_error(dht_error_protocol, "info hash too short"); - if (!m_router->token_valid(arg.get_key_string("token"), sa)) + if (!m_router->token_valid(req[key_a_token].as_sstring(), sa)) throw dht_error(dht_error_protocol, "Token invalid."); DhtTracker* tracker = m_router->get_tracker(*HashString::cast_from(info_hash), true); - tracker->add_peer(sa->sa_inet()->address_n(), arg.get_key_value("port")); + tracker->add_peer(sa->sa_inet()->address_n(), req[key_a_port].as_value()); } void -DhtServer::process_response(int transactionId, const HashString& id, const rak::socket_address* sa, Object& request) { +DhtServer::process_response(const HashString& id, const rak::socket_address* sa, const DhtMessage& response) { + int transactionId = (unsigned char)response[key_t].as_sstring()[2]; transaction_itr itr = m_transactions.find(DhtTransaction::key(sa, transactionId)); // Response to a transaction we don't have in our table. At this point it's @@ -351,11 +376,9 @@ DhtServer::process_response(int transactionId, const HashString& id, const rak:: if ((id != transaction->id() && transaction->id() != m_router->zero_id)) return; - const Object& response = request.get_key("r"); - switch (transaction->type()) { case DhtTransaction::DHT_FIND_NODE: - parse_find_node_reply(transaction->as_find_node(), response.get_key_string("nodes")); + parse_find_node_reply(transaction->as_find_node(), response[key_r_nodes].as_sstring()); break; case DhtTransaction::DHT_GET_PEERS: @@ -381,7 +404,8 @@ DhtServer::process_response(int transactionId, const HashString& id, const rak:: } void -DhtServer::process_error(int transactionId, const rak::socket_address* sa, Object& request) { +DhtServer::process_error(const rak::socket_address* sa, const DhtMessage& error) { + int transactionId = (unsigned char)error[key_t].as_sstring()[2]; transaction_itr itr = m_transactions.find(DhtTransaction::key(sa, transactionId)); if (itr == m_transactions.end()) @@ -399,7 +423,7 @@ DhtServer::process_error(int transactionId, const rak::socket_address* sa, Objec } void -DhtServer::parse_find_node_reply(DhtTransactionSearch* transaction, const std::string& nodes) { +DhtServer::parse_find_node_reply(DhtTransactionSearch* transaction, SimpleString nodes) { transaction->complete(true); if (sizeof(const compact_node_info) != 26) @@ -421,16 +445,16 @@ DhtServer::parse_find_node_reply(DhtTransactionSearch* transaction, const std::s } void -DhtServer::parse_get_peers_reply(DhtTransactionGetPeers* transaction, const Object& response) { +DhtServer::parse_get_peers_reply(DhtTransactionGetPeers* transaction, const DhtMessage& response) { DhtAnnounce* announce = static_cast(transaction->as_search()->search()); transaction->complete(true); - if (response.has_key_list("values")) - announce->receive_peers(response.get_key("values")); + if (response[key_r_values].is_sstring()) + announce->receive_peers(response[key_r_values].as_sstring()); - if (response.has_key_string("token")) - add_transaction(new DhtTransactionAnnouncePeer(transaction->id(), transaction->address(), announce->target(), response.get_key_string("token")), packet_prio_low); + if (response[key_r_token].is_sstring()) + add_transaction(new DhtTransactionAnnouncePeer(transaction->id(), transaction->address(), announce->target(), response[key_r_token].as_sstring()), packet_prio_low); announce->update_status(); } @@ -490,17 +514,19 @@ DhtServer::create_query(transaction_itr itr, int tID, const rak::socket_address* if (itr->second->id() == m_router->id()) throw internal_error("DhtServer::create_query trying to send to itself."); - Object query = Object::create_map(); + DhtMessage query; - DhtTransaction* transaction = itr->second; - char trans_id = tID; - query.insert_key("t", std::string(&trans_id, 1)); - query.insert_key("y", "q"); - query.insert_key("q", queries[transaction->type()]); - query.insert_key("v", PEER_VERSION); + // Transaction ID is a bencode string. + query[key_t] = SimpleString(query.data_end, 3); + *query.data_end++ = '1'; + *query.data_end++ = ':'; + *query.data_end++ = tID; - Object& q = query.insert_key("a", Object::create_map()); - q.insert_key("id", m_router->str()); + DhtTransaction* transaction = itr->second; + query[key_y] = SimpleString("q", 1); + query[key_q] = SimpleString(queries[transaction->type()]); + query[key_v] = SimpleString(PEER_VERSION, 4); + query[key_a_id] = m_router->s_str(); switch (transaction->type()) { case DhtTransaction::DHT_PING: @@ -508,17 +534,17 @@ DhtServer::create_query(transaction_itr itr, int tID, const rak::socket_address* break; case DhtTransaction::DHT_FIND_NODE: - q.insert_key("target", transaction->as_find_node()->search()->target().str()); + query[key_a_target] = transaction->as_find_node()->search()->target().s_str(); break; case DhtTransaction::DHT_GET_PEERS: - q.insert_key("info_hash", transaction->as_get_peers()->search()->target().str()); + query[key_a_infoHash] = transaction->as_get_peers()->search()->target().s_str(); break; case DhtTransaction::DHT_ANNOUNCE_PEER: - q.insert_key("info_hash", transaction->as_announce_peer()->info_hash().str()); - q.insert_key("token", transaction->as_announce_peer()->token()); - q.insert_key("port", manager->connection_manager()->listen_port()); + query[key_a_infoHash] = transaction->as_announce_peer()->info_hash().s_str(); + query[key_a_token] = transaction->as_announce_peer()->token(); + query[key_a_port] = manager->connection_manager()->listen_port(); break; } @@ -530,31 +556,26 @@ DhtServer::create_query(transaction_itr itr, int tID, const rak::socket_address* } void -DhtServer::create_response(const Object& transactionId, const rak::socket_address* sa, Object& r) { - Object reply = Object::create_map(); - r.insert_key("id", m_router->str()); - - reply.insert_key("t", transactionId); - reply.insert_key("y", "r"); - reply.insert_key("r", r); - reply.insert_key("v", PEER_VERSION); +DhtServer::create_response(const DhtMessage& req, const rak::socket_address* sa, DhtMessage& reply) { + reply[key_r_id] = m_router->s_str(); + reply[key_t] = req[key_t]; + reply[key_y] = SimpleString("r", 1); + reply[key_v] = SimpleString(PEER_VERSION, 4); add_packet(new DhtTransactionPacket(sa, reply), packet_prio_reply); } void -DhtServer::create_error(const Object* transactionId, const rak::socket_address* sa, int num, const std::string& msg) { - Object error = Object::create_map(); +DhtServer::create_error(const DhtMessage& req, const rak::socket_address* sa, int num, const char* msg) { + DhtMessage error; - if (transactionId != NULL) - error.insert_key("t", *transactionId); + if (req[key_t].is_sstring()) + error[key_t] = req[key_t]; - error.insert_key("y", "e"); - error.insert_key("v", PEER_VERSION); - - Object& e = error.insert_key("e", Object::create_list()); - e.insert_back(num); - e.insert_back(msg); + error[key_y] = SimpleString("e", 1); + error[key_v] = SimpleString(PEER_VERSION, 4); + error[key_e_0] = num; + error[key_e_1] = SimpleString(msg); add_packet(new DhtTransactionPacket(sa, error), packet_prio_reply); } @@ -656,15 +677,12 @@ DhtServer::clear_transactions() { void DhtServer::event_read() { uint32_t total = 0; - std::istringstream sstream; - - sstream.imbue(std::locale::classic()); while (true) { Object request; rak::socket_address sa; int type = '?'; - const Object* transactionId = NULL; + DhtMessage message; const HashString* nodeId = NULL; try { @@ -675,31 +693,32 @@ DhtServer::event_read() { break; total += read; - sstream.str(std::string(buffer, read)); - - sstream >> request; // If it's not a valid bencode dictionary at all, it's probably not a DHT // packet at all, so we don't throw an error to prevent bounce loops. - if (sstream.fail() || !request.is_map()) + try { + staticMap_read_bencode(buffer, buffer + read, message); + } catch (bencode_error& e) { continue; + } - if (!request.has_key("t")) + if (!message[key_t].is_sstring()) throw dht_error(dht_error_protocol, "No transaction ID"); - transactionId = &request.get_key("t"); - - if (!request.has_key_string("y")) + if (!message[key_y].is_sstring()) throw dht_error(dht_error_protocol, "No message type"); - if (request.get_key_string("y").length() != 1) + if (message[key_y].as_sstring().length() != 1) throw dht_error(dht_error_bad_method, "Unsupported message type"); - type = request.get_key_string("y")[0]; + type = message[key_y].as_sstring()[0]; // Queries and replies have node ID in different dictionaries. if (type == 'r' || type == 'q') { - const std::string& nodeIdStr = request.get_key(type == 'q' ? "a" : "r").get_key_string("id"); + if (!message[type == 'q' ? key_a_id : key_r_id].is_sstring()) + throw dht_error(dht_error_protocol, "Invalid `id' value"); + + SimpleString nodeIdStr = message[type == 'q' ? key_a_id : key_r_id].as_sstring(); if (nodeIdStr.length() < HashString::size_data) throw dht_error(dht_error_protocol, "`id' value too short"); @@ -709,7 +728,8 @@ DhtServer::event_read() { // Sanity check the returned transaction ID. if ((type == 'r' || type == 'e') && - (!transactionId->is_string() || transactionId->as_string().length() != 1)) + (!message[key_t].is_sstring() || message[key_t].as_sstring().length() != 3 + || message[key_t].as_sstring()[0] != '1' || message[key_t].as_sstring()[1] != ':')) throw dht_error(dht_error_protocol, "Invalid transaction ID type/length."); // Stupid broken implementations. @@ -718,15 +738,15 @@ DhtServer::event_read() { switch (type) { case 'q': - process_query(*transactionId, *nodeId, &sa, request); + process_query(*nodeId, &sa, message); break; case 'r': - process_response(((unsigned char*)transactionId->as_string().c_str())[0], *nodeId, &sa, request); + process_response(*nodeId, &sa, message); break; case 'e': - process_error(((unsigned char*)transactionId->as_string().c_str())[0], &sa, request); + process_error(&sa, message); break; default: @@ -737,16 +757,19 @@ DhtServer::event_read() { // so that if it repeatedly sends malformed replies we will drop it instead of propagating it // to other nodes. } catch (bencode_error& e) { - if ((type == 'r' || type == 'e') && nodeId != NULL) + if ((type == 'r' || type == 'e') && nodeId != NULL) { m_router->node_inactive(*nodeId, &sa); - else - create_error(transactionId, &sa, dht_error_protocol, std::string("Malformed packet: ") + e.what()); + } else { + snprintf(message.data_end, message.data + message.data_size - message.data_end - 1, "Malformed packet: %s", e.what()); + message.data[message.data_size - 1] = 0; + create_error(message, &sa, dht_error_protocol, message.data_end); + } } catch (dht_error& e) { if ((type == 'r' || type == 'e') && nodeId != NULL) m_router->node_inactive(*nodeId, &sa); else - create_error(transactionId, &sa, e.code(), e.what()); + create_error(message, &sa, e.code(), e.what()); } catch (network_error& e) { diff --git a/src/dht/dht_server.h b/src/dht/dht_server.h index 1855b73..1f55f15 100644 --- a/src/dht/dht_server.h +++ b/src/dht/dht_server.h @@ -46,6 +46,7 @@ #include "net/throttle_node.h" #include "download/download_info.h" // for SocketAddressCompact #include "torrent/hash_string.h" +#include "torrent/simple_string.h" #include "dht_transaction.h" @@ -56,6 +57,7 @@ class DhtNode; class DhtRouter; class DownloadInfo; +class DhtMessage; class TrackerDht; // UDP server that handles the DHT node communications. @@ -134,23 +136,23 @@ private: void start_write(); - void process_query(const Object& transaction, const HashString& id, const rak::socket_address* sa, Object& req); - void process_response(int transaction, const HashString& id, const rak::socket_address* sa, Object& req); - void process_error(int transaction, const rak::socket_address* sa, Object& req); + void process_query(const HashString& id, const rak::socket_address* sa, const DhtMessage& req); + void process_response(const HashString& id, const rak::socket_address* sa, const DhtMessage& req); + void process_error(const rak::socket_address* sa, const DhtMessage& error); - void parse_find_node_reply(DhtTransactionSearch* t, const std::string& nodes); - void parse_get_peers_reply(DhtTransactionGetPeers* t, const Object& res); + void parse_find_node_reply(DhtTransactionSearch* t, SimpleString res); + void parse_get_peers_reply(DhtTransactionGetPeers* t, const DhtMessage& res); void find_node_next(DhtTransactionSearch* t); void add_packet(DhtTransactionPacket* packet, int priority); void create_query(transaction_itr itr, int tID, const rak::socket_address* sa, int priority); - void create_response(const Object& transactionID, const rak::socket_address* sa, Object& r); - void create_error(const Object* transactionID, const rak::socket_address* sa, int num, const std::string& msg); + void create_response(const DhtMessage& req, const rak::socket_address* sa, DhtMessage& reply); + void create_error(const DhtMessage& req, const rak::socket_address* sa, int num, const char* msg); - void create_find_node_response(const Object& arg, Object& reply); - void create_get_peers_response(const Object& arg, const rak::socket_address* sa, Object& reply); - void create_announce_peer_response(const Object& arg, const rak::socket_address* sa, Object& reply); + void create_find_node_response(const DhtMessage& arg, DhtMessage& reply); + void create_get_peers_response(const DhtMessage& arg, const rak::socket_address* sa, DhtMessage& reply); + void create_announce_peer_response(const DhtMessage& arg, const rak::socket_address* sa, DhtMessage& reply); int add_transaction(DhtTransaction* t, int priority); diff --git a/src/dht/dht_tracker.cc b/src/dht/dht_tracker.cc index 416dbf3..6e1afe9 100644 --- a/src/dht/dht_tracker.cc +++ b/src/dht/dht_tracker.cc @@ -54,8 +54,8 @@ DhtTracker::add_peer(uint32_t addr, uint16_t port) { // Check if peer exists. If not, find oldest peer. for (unsigned int i = 0; i < size(); i++) { - if (m_peers[i].addr == compact.addr) { - m_peers[i].port = compact.port; + if (m_peers[i].peer.addr == compact.addr) { + m_peers[i].peer.port = compact.port; m_lastSeen[i] = cachedTime.seconds(); return; @@ -77,10 +77,13 @@ DhtTracker::add_peer(uint32_t addr, uint16_t port) { } } -// Return compact info (6 bytes) for up to 30 peers, returning different -// peers for each call if there are more. -Object +// Return compact info as bencoded string (8 bytes per peer) for up to 30 peers, +// returning different peers for each call if there are more. +SimpleString DhtTracker::get_peers(unsigned int maxPeers) { + if (sizeof(BencodeAddress) != 8) + throw internal_error("DhtTracker::BencodeAddress is packed incorrectly."); + PeerList::iterator first = m_peers.begin(); PeerList::iterator last = m_peers.end(); @@ -94,11 +97,7 @@ DhtTracker::get_peers(unsigned int maxPeers) { last = first + maxPeers; } - Object peers = Object::create_list(); - for (; first != last; ++first) - peers.insert_back(std::string(first->c_str(), sizeof(*first))); - - return peers; + return SimpleString(first->bencode(), last->bencode() - first->bencode()); } // Remove old announces. @@ -107,9 +106,9 @@ DhtTracker::prune(uint32_t maxAge) { uint32_t minSeen = cachedTime.seconds() - maxAge; for (unsigned int i = 0; i < m_lastSeen.size(); i++) - if (m_lastSeen[i] < minSeen) m_peers[i].port = 0; + if (m_lastSeen[i] < minSeen) m_peers[i].peer.port = 0; - m_peers.erase(std::remove_if(m_peers.begin(), m_peers.end(), rak::on(rak::mem_ref(&SocketAddressCompact::port), std::bind2nd(std::equal_to(), 0))), m_peers.end()); + m_peers.erase(std::remove_if(m_peers.begin(), m_peers.end(), std::mem_fun_ref(&BencodeAddress::empty)), m_peers.end()); m_lastSeen.erase(std::remove_if(m_lastSeen.begin(), m_lastSeen.end(), std::bind2nd(std::less(), minSeen)), m_lastSeen.end()); if (m_peers.size() != m_lastSeen.size()) diff --git a/src/dht/dht_tracker.h b/src/dht/dht_tracker.h index 8515dd0..53fd1e3 100644 --- a/src/dht/dht_tracker.h +++ b/src/dht/dht_tracker.h @@ -43,6 +43,7 @@ #include #include "download/download_info.h" // for SocketAddressCompact +#include "torrent/simple_string.h" namespace torrent { @@ -65,14 +66,26 @@ public: size_t size() const { return m_peers.size(); } void add_peer(uint32_t addr, uint16_t port); - Object get_peers(unsigned int maxPeers = max_peers); + SimpleString get_peers(unsigned int maxPeers = max_peers); // Remove old announces from the tracker that have not reannounced for // more than the given number of seconds. void prune(uint32_t maxAge); private: - typedef std::vector PeerList; + // We need to store the address as a bencoded string. + struct BencodeAddress { + char header[2]; + SocketAddressCompact peer; + + BencodeAddress(const SocketAddressCompact& p) : peer(p) { header[0] = '6'; header[1] = ':'; } + + const char* bencode() const { return header; } + + bool empty() const { return !peer.port; } + } __attribute__ ((packed)); + + typedef std::vector PeerList; PeerList m_peers; std::vector m_lastSeen; diff --git a/src/dht/dht_transaction.cc b/src/dht/dht_transaction.cc index 2a6a8a6..0b3cfd0 100644 --- a/src/dht/dht_transaction.cc +++ b/src/dht/dht_transaction.cc @@ -123,7 +123,7 @@ DhtSearch::trim(bool final) { // We keep: // - the max_contacts=18 closest good or unknown nodes and all nodes closer // than them (to see if further searches find closer ones) - // - for announces, also the 8 closest good nodes (i.e. nodes that have + // - for announces, also the 3 closest good nodes (i.e. nodes that have // replied) to have at least that many for the actual announce // - any node that currently has transactions pending // @@ -136,7 +136,7 @@ DhtSearch::trim(bool final) { // node is new and unknown otherwise int needClosest = final ? 0 : max_contacts; - int needGood = is_announce() ? DhtBucket::num_nodes : 0; + int needGood = is_announce() ? max_announce : 0; // We're done if we can't find any more nodes to contact. m_next = end(); @@ -252,7 +252,7 @@ DhtAnnounce::start_announce() { } void -DhtAnnounce::receive_peers(const Object& peers) { +DhtAnnounce::receive_peers(SimpleString peers) { m_tracker->receive_peers(peers); } @@ -262,9 +262,9 @@ DhtAnnounce::update_status() { } void -DhtTransactionPacket::build_buffer(const Object& data) { +DhtTransactionPacket::build_buffer(const DhtMessage& msg) { char buffer[1500]; // If the message would exceed an Ethernet frame, something went very wrong. - object_buffer_t result = object_write_bencode_c(object_write_to_buffer, NULL, std::make_pair(buffer, buffer + sizeof(buffer)), &data); + object_buffer_t result = staticMap_write_bencode_c(object_write_to_buffer, NULL, std::make_pair(buffer, buffer + sizeof(buffer)), msg); m_length = result.second - buffer; m_data = new char[m_length]; @@ -277,7 +277,6 @@ DhtTransaction::DhtTransaction(int quick_timeout, int timeout, const HashString& m_sa(*sa), m_timeout(cachedTime.seconds() + timeout), m_quickTimeout(cachedTime.seconds() + quick_timeout), - m_retry(3), m_packet(NULL) { } diff --git a/src/dht/dht_transaction.h b/src/dht/dht_transaction.h index 194316d..43b42ab 100644 --- a/src/dht/dht_transaction.h +++ b/src/dht/dht_transaction.h @@ -43,6 +43,7 @@ #include "dht/dht_node.h" #include "torrent/hash_string.h" +#include "torrent/static_map.h" namespace torrent { @@ -93,6 +94,9 @@ public: // Number of closest potential contact nodes to keep. static const unsigned int max_contacts = 18; + // Number of closest nodes we actually announce to. + static const unsigned int max_announce = 3; + DhtSearch(const HashString& target, const DhtBucket& contacts); virtual ~DhtSearch(); @@ -178,22 +182,66 @@ public: // counts announces instead. const_accessor start_announce(); - void receive_peers(const Object& peer_list); + void receive_peers(SimpleString peers); void update_status(); private: TrackerDht* m_tracker; }; +// Possible bencode keys in a DHT message. +enum dht_keys { + key_a_id, + key_a_infoHash, + key_a_port, + key_a_target, + key_a_token, + + key_e_0, + key_e_1, + + key_q, + + key_r_id, + key_r_nodes, + key_r_token, + key_r_values, + + key_t, + key_v, + key_y, + + key_LAST, +}; + +class DhtMessage : public StaticMap { +public: + typedef StaticMap base_type; + typedef StaticMapKeys::mapping_type mapping_type; + + DhtMessage() : data_end(data) {}; + + // Must be big enough to hold one of the possible variable-sized reply data. + // Currently either: + // - error message (size doesn't really matter, it'll be truncated at worst) + // - announce token (8 bytes, needs 20 bytes buffer to build) + // Never more than one of the above. + // And additionally for queries we send: + // - transaction ID (3 bytes) + static const size_t data_size = 64; + char data[data_size]; + char* data_end; +}; + // Class holding transaction data to be transmitted. class DhtTransactionPacket { public: // transaction packet - DhtTransactionPacket(const rak::socket_address* s, const Object& d, unsigned int id, DhtTransaction* t) + DhtTransactionPacket(const rak::socket_address* s, const DhtMessage& d, unsigned int id, DhtTransaction* t) : m_sa(*s), m_id(id), m_transaction(t) { build_buffer(d); }; // non-transaction packet - DhtTransactionPacket(const rak::socket_address* s, const Object& d) + DhtTransactionPacket(const rak::socket_address* s, const DhtMessage& d) : m_sa(*s), m_id(-cachedTime.seconds()), m_transaction(NULL) { build_buffer(d); }; ~DhtTransactionPacket() { delete[] m_data; } @@ -214,7 +262,7 @@ public: DhtTransaction* transaction() { return m_transaction; } private: - void build_buffer(const Object& data); + void build_buffer(const DhtMessage& data); rak::socket_address m_sa; char* m_data; @@ -255,9 +303,6 @@ public: int quick_timeout() { return m_quickTimeout; } bool has_quick_timeout() { return m_hasQuickTimeout; } - int dec_retry() { return m_retry--; } - int retry() { return m_retry; } - DhtTransactionPacket* packet() { return m_packet; } void set_packet(DhtTransactionPacket* p) { m_packet = p; } @@ -282,7 +327,6 @@ private: rak::socket_address m_sa; int m_timeout; int m_quickTimeout; - int m_retry; DhtTransactionPacket* m_packet; }; @@ -337,7 +381,7 @@ public: class DhtTransactionAnnouncePeer : public DhtTransaction { public: - DhtTransactionAnnouncePeer(const HashString& id, const rak::socket_address* sa, const HashString& infoHash, const std::string& token) + DhtTransactionAnnouncePeer(const HashString& id, const rak::socket_address* sa, const HashString& infoHash, SimpleString token) : DhtTransaction(-1, 30, id, sa), m_infoHash(infoHash), m_token(token) { } @@ -345,11 +389,11 @@ public: virtual transaction_type type() { return DHT_ANNOUNCE_PEER; } const HashString& info_hash() { return m_infoHash; } - const std::string& token() { return m_token; } + SimpleString token() { return m_token; } private: HashString m_infoHash; - std::string m_token; + SimpleString m_token; }; inline bool diff --git a/src/download/download_constructor.cc b/src/download/download_constructor.cc index fc2a272..86e5351 100644 --- a/src/download/download_constructor.cc +++ b/src/download/download_constructor.cc @@ -36,6 +36,7 @@ #include "config.h" +#include #include #include #include @@ -80,7 +81,10 @@ struct download_constructor_encoding_match : }; void -DownloadConstructor::initialize(const Object& b) { +DownloadConstructor::initialize(Object& b) { + if (!b.has_key_map("info") && b.has_key_string("magnet-uri")) + parse_magnet_uri(b, b.get_key_string("magnet-uri")); + if (b.has_key_string("encoding")) m_defaultEncoding = b.get_key_string("encoding"); @@ -135,10 +139,24 @@ DownloadConstructor::parse_info(const Object& b) { if (b.flags() & Object::flag_unordered) throw input_error("Download has unordered info dictionary."); - uint32_t chunkSize = b.get_key_value("piece length"); + uint32_t chunkSize; + + if (b.has_key_value("meta_download") && b.get_key_value("meta_download")) + m_download->info()->set_meta_download(true); + + if (m_download->info()->is_meta_download()) { + if (b.get_key_string("pieces").length() != HashString::size_data) + throw input_error("Meta-download has invalid piece data."); + + chunkSize = 1; + parse_single_file(b, chunkSize); + + } else { + chunkSize = b.get_key_value("piece length"); - if (chunkSize <= (1 << 10) || chunkSize > (128 << 20)) - throw input_error("Torrent has an invalid \"piece length\"."); + if (chunkSize <= (1 << 10) || chunkSize > (128 << 20)) + throw input_error("Torrent has an invalid \"piece length\"."); + } if (b.has_key("length")) { parse_single_file(b, chunkSize); @@ -147,11 +165,11 @@ DownloadConstructor::parse_info(const Object& b) { parse_multi_files(b.get_key("files"), chunkSize); fileList->set_root_dir("./" + m_download->info()->name()); - } else { + } else if (!m_download->info()->is_meta_download()) { throw input_error("Torrent must have either length or files entry."); } - if (fileList->size_bytes() == 0) + if (fileList->size_bytes() == 0 && !m_download->info()->is_meta_download()) throw input_error("Torrent has zero length."); // Set chunksize before adding files to make sure the index range is @@ -238,7 +256,7 @@ DownloadConstructor::parse_single_file(const Object& b, uint32_t chunkSize) { throw input_error("Bad torrent file, \"name\" is an invalid path name."); FileList* fileList = m_download->main()->file_list(); - fileList->initialize(b.get_key_value("length"), chunkSize); + fileList->initialize(chunkSize == 1 ? 1 : b.get_key_value("length"), chunkSize); fileList->set_multi_file(false); std::list pathList; @@ -342,4 +360,132 @@ DownloadConstructor::choose_path(std::list* pathList) { return pathList->front(); } +static const char* +parse_base32_sha1(const char* pos, HashString& hash) { + HashString::iterator hashItr = hash.begin(); + + static const int base_shift = 8+8-5; + int shift = base_shift; + uint16_t decoded = 0; + + while (*pos) { + char c = *pos++; + uint16_t value; + + if (c >= 'A' && c <= 'Z') + value = c - 'A'; + else if (c >= 'a' && c <= 'z') + value = c - 'a'; + else if (c >= '2' && c <= '7') + value = 26 + c - '2'; + else if (c == '&') + break; + else + return NULL; + + decoded |= (value << shift); + if (shift <= 8) { + // Too many characters for a base32 SHA1. + if (hashItr == hash.end()) + return NULL; + + *hashItr++ = (decoded >> 8); + decoded <<= 8; + shift += 3; + } else { + shift -= 5; + } + } + + return hashItr != hash.end() || shift != base_shift ? NULL : pos; +} + +void +DownloadConstructor::parse_magnet_uri(Object& b, const std::string& uri) { + if (std::strncmp(uri.c_str(), "magnet:?", 8)) + throw input_error("Invalid magnet URI."); + + const char* pos = uri.c_str() + 8; + + Object trackers(Object::create_list()); + HashString hash; + bool hashValid = false; + + while (*pos) { + const char* tagStart = pos; + while (*pos != '=') + if (!*pos++) + break; + + SimpleString tag(tagStart, pos - tagStart); + pos++; + + // hash may be base32 encoded (optional in BEP 0009 and common practice) + if (tag == "xt") { + if (strncmp(pos, "urn:btih:", 9)) + throw input_error("Invalid magnet URI."); + + pos += 9; + + const char* nextPos = parse_base32_sha1(pos, hash); + if (nextPos != NULL) { + pos = nextPos; + hashValid = true; + continue; + } + } + + // everything else, including sometimes the hash, is url encoded. + std::string decoded; + while (*pos) { + char c = *pos++; + if (c == '%') { + if (sscanf(pos, "%02hhx", &c) != 1) + throw input_error("Invalid magnet URI."); + + pos += 2; + + } else if (c == '&') { + break; + } + + decoded.push_back(c); + } + + if (tag == "xt") { + // url-encoded hash as per magnet URN specs + if (decoded.length() == hash.size_data) { + hash = *HashString::cast_from(decoded); + hashValid = true; + + // hex-encoded hash as per BEP 0009 + } else if (decoded.length() == hash.size_data * 2) { + std::string::iterator hexItr = decoded.begin(); + for (HashString::iterator itr = hash.begin(), last = hash.end(); itr != last; itr++, hexItr += 2) + *itr = (rak::hexchar_to_value(*hexItr) << 4) + rak::hexchar_to_value(*(hexItr + 1)); + hashValid = true; + + } else { + throw input_error("Invalid magnet URI."); + } + } else if (tag == "tr") { + trackers.insert_back(Object::create_list()).insert_back(decoded); + } + // could also handle "dn" = display name (torrent name), but we can't really use that + } + + if (!hashValid) + throw input_error("Invalid magnet URI."); + + Object& info = b.insert_key("info", Object::create_map()); + info.insert_key("pieces", hash.str()); + info.insert_key("name", rak::transform_hex(hash.str()) + ".meta"); + info.insert_key("meta_download", (int64_t)1); + + if (!trackers.as_list().empty()) { + b.insert_preserve_copy("announce", trackers.as_list().begin()->as_list().begin()->as_string()); + b.insert_preserve_type("announce-list", trackers); + } +} + } diff --git a/src/download/download_constructor.h b/src/download/download_constructor.h index 7192f90..8af520f 100644 --- a/src/download/download_constructor.h +++ b/src/download/download_constructor.h @@ -55,7 +55,7 @@ class DownloadConstructor { public: DownloadConstructor() : m_download(NULL), m_encodingList(NULL) {} - void initialize(const Object& b); + void initialize(Object& b); void set_download(DownloadWrapper* d) { m_download = d; } void set_encoding_list(const EncodingList* e) { m_encodingList = e; } @@ -64,6 +64,7 @@ private: void parse_name(const Object& b); void parse_tracker(const Object& b); void parse_info(const Object& b); + void parse_magnet_uri(Object& b, const std::string& uri); void add_tracker_group(const Object& b); void add_tracker_single(const Object& b, int group); diff --git a/src/download/download_info.h b/src/download/download_info.h index 0a3c0e8..68fb178 100644 --- a/src/download/download_info.h +++ b/src/download/download_info.h @@ -76,6 +76,7 @@ public: m_isCompact(true), m_isAcceptingNewPeers(true), m_isPrivate(false), + m_isMetaDownload(false), m_pexEnabled(true), m_pexActive(true), @@ -86,7 +87,8 @@ public: m_uploadedBaseline(0), m_completedBaseline(0), m_sizePex(0), - m_maxSizePex(8) { + m_maxSizePex(8), + m_metadataSize(0) { } const std::string& name() const { return m_name; } @@ -116,6 +118,9 @@ public: bool is_private() const { return m_isPrivate; } void set_private(bool p) { m_isPrivate = p; if (p) m_pexEnabled = false; } + bool is_meta_download() const { return m_isMetaDownload; } + void set_meta_download(bool m) { m_isMetaDownload = m; } + bool is_pex_enabled() const { return m_pexEnabled; } void set_pex_enabled(bool enabled) { m_pexEnabled = enabled && !m_isPrivate; } @@ -134,6 +139,9 @@ public: uint64_t completed_adjusted() const { return std::max(m_slotStatCompleted() - completed_baseline(), 0); } void set_completed_baseline(uint64_t b) { m_completedBaseline = b; } + size_t metadata_size() const { return m_metadataSize; } + void set_metadata_size(size_t size) { m_metadataSize = size; } + uint32_t size_pex() const { return m_sizePex; } void set_size_pex(uint32_t b) { m_sizePex = b; } @@ -165,6 +173,7 @@ private: bool m_isCompact; bool m_isAcceptingNewPeers; bool m_isPrivate; + bool m_isMetaDownload; bool m_pexEnabled; bool m_pexActive; @@ -176,6 +185,7 @@ private: uint64_t m_completedBaseline; uint32_t m_sizePex; uint32_t m_maxSizePex; + size_t m_metadataSize; slot_stat_type m_slotStatCompleted; slot_stat_type m_slotStatLeft; diff --git a/src/download/download_main.cc b/src/download/download_main.cc index 1dd5f98..5691021 100644 --- a/src/download/download_main.cc +++ b/src/download/download_main.cc @@ -455,4 +455,19 @@ DownloadMain::do_peer_exchange() { } } +void +DownloadMain::set_metadata_size(size_t size) { + if (m_info->is_meta_download()) { + if (m_fileList.size_bytes() < 2) + file_list()->reset_filesize(size); + else if (size != m_fileList.size_bytes()) + throw communication_error("Peer-supplied metadata size mismatch."); + + } else if (m_info->metadata_size() && m_info->metadata_size() != size) { + throw communication_error("Peer-supplied metadata size mismatch."); + } + + m_info->set_metadata_size(size); +} + } diff --git a/src/download/download_main.h b/src/download/download_main.h index 5d0090b..700f41e 100644 --- a/src/download/download_main.h +++ b/src/download/download_main.h @@ -116,6 +116,8 @@ public: bool want_pex_msg() { return m_info->is_pex_active() && m_peerList.available_list()->want_more(); }; + void set_metadata_size(size_t s); + // Carefull with these. void setup_delegator(); void setup_tracker(); diff --git a/src/net/address_list.cc b/src/net/address_list.cc index 2fc3992..e5cf3cb 100644 --- a/src/net/address_list.cc +++ b/src/net/address_list.cc @@ -70,7 +70,7 @@ AddressList::parse_address_normal(const Object::list_type& b) { } void -AddressList::parse_address_compact(const std::string& s) { +AddressList::parse_address_compact(SimpleString s) { if (sizeof(const SocketAddressCompact) != 6) throw internal_error("ConnectionList::AddressList::parse_address_compact(...) bad struct size."); @@ -79,4 +79,18 @@ AddressList::parse_address_compact(const std::string& s) { std::back_inserter(*this)); } +void +AddressList::parse_address_bencode(SimpleString s) { + if (sizeof(const SocketAddressCompact) != 6) + throw internal_error("AddressList::parse_address_bencode(...) bad struct size."); + + while (s.length() >= 2 + sizeof(SocketAddressCompact)) { + if (s[0] != '6' || s[1] != ':') + break; + + insert(end(), *reinterpret_cast(s.c_str() + 2)); + s = SimpleString(s.c_str() + 2 + sizeof(SocketAddressCompact), s.length() - 2 - sizeof(SocketAddressCompact)); + } +} + } diff --git a/src/net/address_list.h b/src/net/address_list.h index e4d2009..10dbac4 100644 --- a/src/net/address_list.h +++ b/src/net/address_list.h @@ -42,6 +42,7 @@ #include #include +#include namespace torrent { @@ -49,7 +50,8 @@ class AddressList : public std::list { public: // Parse normal or compact list of addresses and add to AddressList void parse_address_normal(const Object::list_type& b); - void parse_address_compact(const std::string& s); + void parse_address_compact(SimpleString s); + void parse_address_bencode(SimpleString s); private: static rak::socket_address parse_address(const Object& b); diff --git a/src/net/data_buffer.h b/src/net/data_buffer.h index a26ca36..e3d9e38 100644 --- a/src/net/data_buffer.h +++ b/src/net/data_buffer.h @@ -48,6 +48,7 @@ struct DataBuffer { DataBuffer(char* data, char* end) : m_data(data), m_end(end), m_owned(true) {} DataBuffer clone() const { DataBuffer d = *this; d.m_owned = false; return d; } + DataBuffer release() { DataBuffer d = *this; set(NULL, NULL, false); return d; } char* data() const { return m_data; } char* end() const { return m_end; } @@ -70,7 +71,7 @@ private: inline void DataBuffer::clear() { - if (!empty()) + if (!empty() && m_owned) delete[] m_data; m_data = m_end = NULL; diff --git a/src/net/socket_base.cc b/src/net/socket_base.cc index 90457dc..13a9c8b 100644 --- a/src/net/socket_base.cc +++ b/src/net/socket_base.cc @@ -47,7 +47,7 @@ namespace torrent { -char* SocketBase::m_nullBuffer = new char[1 << 17]; +char* SocketBase::m_nullBuffer = new char[SocketBase::null_buffer_size]; SocketBase::~SocketBase() { if (get_fd().is_valid()) diff --git a/src/net/socket_base.h b/src/net/socket_base.h index 9340a23..0f0f424 100644 --- a/src/net/socket_base.h +++ b/src/net/socket_base.h @@ -68,6 +68,8 @@ protected: SocketBase(const SocketBase&); void operator = (const SocketBase&); + static const size_t null_buffer_size = 1 << 17; + static char* m_nullBuffer; }; diff --git a/src/protocol/Makefile.am b/src/protocol/Makefile.am index 6171d06..18f671d 100644 --- a/src/protocol/Makefile.am +++ b/src/protocol/Makefile.am @@ -17,6 +17,8 @@ libsub_protocol_la_SOURCES = \ peer_connection_base.h \ peer_connection_leech.cc \ peer_connection_leech.h \ + peer_connection_metadata.cc \ + peer_connection_metadata.h \ peer_factory.cc \ peer_factory.h \ protocol_base.h \ diff --git a/src/protocol/extensions.cc b/src/protocol/extensions.cc index f3464af..3e0cf60 100644 --- a/src/protocol/extensions.cc +++ b/src/protocol/extensions.cc @@ -37,30 +37,103 @@ #include "config.h" #include -#include +#include #include #include "download/available_list.h" #include "download/download_main.h" +#include "download/download_manager.h" +#include "download/download_wrapper.h" #include "protocol/peer_connection_base.h" #include "torrent/connection_manager.h" #include "torrent/object.h" #include "torrent/object_stream.h" #include "torrent/peer/connection_list.h" #include "torrent/peer/peer_info.h" -#include "tracker/tracker_http.h" +#include "torrent/static_map.h" #include "manager.h" #include "extensions.h" namespace torrent { -const char* ProtocolExtension::message_keys[] = { - "HANDSHAKE", - "ut_pex", +enum ext_handshake_keys { + key_e, + key_m_utMetadata, + key_m_utPex, + key_metadataSize, + key_p, + key_reqq, + key_v, + key_handshake_LAST }; +enum ext_pex_keys { + key_pex_added, + key_pex_LAST +}; + +enum ext_metadata_keys { + key_msgType, + key_piece, + key_totalSize, + key_metadata_LAST +}; + +class ExtHandshakeMessage : public StaticMap { +public: + typedef StaticMap base_type; + typedef StaticMapKeys::mapping_type mapping_type; +}; + +class ExtPEXMessage : public StaticMap { +public: + typedef StaticMap base_type; + typedef StaticMapKeys::mapping_type mapping_type; +}; + +class ExtMetadataMessage : public StaticMap { +public: + typedef StaticMap base_type; + typedef StaticMapKeys::mapping_type mapping_type; +}; + +ExtHandshakeMessage::mapping_type ext_handshake_key_names[ExtHandshakeMessage::length] = { + { key_e, "e" }, + { key_m_utMetadata, "m::ut_metadata" }, + { key_m_utPex, "m::ut_pex" }, + { key_metadataSize, "metadata_size" }, + { key_p, "p" }, + { key_reqq, "reqq" }, + { key_v, "v" }, +}; + +ExtPEXMessage::mapping_type ext_pex_key_names[ExtPEXMessage::length] = { + { key_pex_added, "added" }, +}; + +ExtMetadataMessage::mapping_type ext_metadata_key_names[ExtMetadataMessage::length] = { + { key_msgType, "msg_type" }, + { key_piece, "piece" }, + { key_totalSize, "total_size" }, +}; + +ext_handshake_keys message_keys[ProtocolExtension::FIRST_INVALID] = { + key_handshake_LAST, // Handshake, not actually used. + key_m_utPex, + key_m_utMetadata, +}; + +template<> +const ExtHandshakeMessage::key_map_init ExtHandshakeMessage::base_type::keyMap(ext_handshake_key_names); + +template<> +const ExtPEXMessage::key_map_init ExtPEXMessage::base_type::keyMap(ext_pex_key_names); + +template<> +const ExtMetadataMessage::key_map_init ExtMetadataMessage::base_type::keyMap(ext_metadata_key_names); + void ProtocolExtension::cleanup() { // if (is_default()) @@ -105,23 +178,25 @@ ProtocolExtension::unset_local_enabled(int t) { DataBuffer ProtocolExtension::generate_handshake_message() { - Object map = Object::create_map(); - Object message = Object::create_map(); - - map.insert_key(message_keys[UT_PEX], is_local_enabled(UT_PEX) ? 1 : 0); + ExtHandshakeMessage message; // Add "e" key if encryption is enabled, set it to 1 if we require // encryption for incoming connections, or 0 otherwise. if ((manager->connection_manager()->encryption_options() & ConnectionManager::encryption_allow_incoming) != 0) - message.insert_key("e", (manager->connection_manager()->encryption_options() & ConnectionManager::encryption_require) != 0); + message[key_e] = (manager->connection_manager()->encryption_options() & ConnectionManager::encryption_require) != 0; + + message[key_p] = manager->connection_manager()->listen_port(); + message[key_v] = SimpleString("libTorrent " VERSION); + message[key_reqq] = 2048; // maximum request queue size + + if (!m_download->info()->is_meta_download()) + message[key_metadataSize] = m_download->info()->metadata_size(); - message.insert_key("m", map); - message.insert_key("p", manager->connection_manager()->listen_port()); - message.insert_key("v", "libTorrent " VERSION); - message.insert_key("reqq", 2048); // maximum request queue size + message[key_m_utPex] = is_local_enabled(UT_PEX) ? UT_PEX : 0; + message[key_m_utMetadata] = UT_METADATA; char buffer[1024]; - object_buffer_t result = object_write_bencode_c(object_write_to_buffer, NULL, std::make_pair(buffer, buffer + sizeof(buffer)), &message); + object_buffer_t result = staticMap_write_bencode_c(object_write_to_buffer, NULL, std::make_pair(buffer, buffer + sizeof(buffer)), message); int length = result.second - buffer; char* copy = new char[length]; @@ -130,21 +205,30 @@ ProtocolExtension::generate_handshake_message() { return DataBuffer(copy, copy + length); } -DataBuffer -ProtocolExtension::generate_toggle_message(ProtocolExtension::MessageType t, bool on) { - // TODO: Check if we're accepting this message type? +inline DataBuffer +ProtocolExtension::build_bencode(size_t maxLength, const char* format, ...) { + char* b = new char[maxLength]; - // Manually create bencoded map { "m" => { message_keys[t] => on ? t : 0 } } - char* b = new char[32]; - unsigned int length = snprintf(b, 32, "d1:md%zu:%si%deee", strlen(message_keys[t]), message_keys[t], on ? t : 0); + va_list args; + va_start(args, format); + unsigned int length = vsnprintf(b, maxLength, format, args); + va_end(args); - if (length > 32) - throw internal_error("ProtocolExtension::toggle_message wrote past buffer."); + if (length > maxLength) + throw internal_error("ProtocolExtension::build_bencode wrote past buffer."); return DataBuffer(b, b + length); } DataBuffer +ProtocolExtension::generate_toggle_message(MessageType t, bool on) { + // TODO: Check if we're accepting this message type? + + // Manually create bencoded map { "m" => { message_keys[t] => on ? t : 0 } } + return build_bencode(32, "d1:md%zu:%si%deee", strlen(ext_handshake_key_names[message_keys[t]].key) - 3, ext_handshake_key_names[message_keys[t]].key + 3, on ? t : 0); +} + +DataBuffer ProtocolExtension::generate_ut_pex_message(const PEXList& added, const PEXList& removed) { if (added.empty() && removed.empty()) return DataBuffer(); @@ -173,7 +257,7 @@ ProtocolExtension::generate_ut_pex_message(const PEXList& added, const PEXList& void ProtocolExtension::read_start(int type, uint32_t length, bool skip) { - if (is_default() || (type >= FIRST_INVALID) || length > (1 << 14)) + if (is_default() || (type >= FIRST_INVALID) || length > (1 << 15)) throw communication_error("Received invalid extension message."); if (m_read != NULL || length < 0) @@ -193,41 +277,42 @@ ProtocolExtension::read_start(int type, uint32_t length, bool skip) { m_readPos = m_read = new char[length]; } -void +bool ProtocolExtension::read_done() { - if (m_readType == SKIP_EXTENSION) { - delete [] m_read; - m_read = NULL; - return; - } + bool blocked = false; - std::stringstream s(std::string(m_read, m_readPos)); - s.imbue(std::locale::classic()); + try { + switch(m_readType) { + case SKIP_EXTENSION: + break; - delete [] m_read; - m_read = NULL; + case HANDSHAKE: + blocked = parse_handshake(); + break; - Object message; - s >> message; + case UT_PEX: + blocked = parse_ut_pex(); + break; - if (s.fail() || !message.is_map()) - throw communication_error("Invalid extension message."); + case UT_METADATA: + blocked = parse_ut_metadata(); + break; - switch(m_readType) { - case HANDSHAKE: - parse_handshake(message); - break; - - case UT_PEX: - parse_ut_pex(message); - break; + default: + throw internal_error("ProtocolExtension::read_done called with invalid extension type."); + } - default: - throw internal_error("ProtocolExtension::down_extension_finished called with invalid extension type."); + } catch (bencode_error& e) { + // Ignore malformed messages. } + delete [] m_read; + m_read = NULL; + m_readType = FIRST_INVALID; m_flags |= flag_received_ext; + + return !blocked; } // Called whenever peer enables or disables an extension. @@ -241,25 +326,23 @@ ProtocolExtension::peer_toggle_remote(int type, bool active) { } } -void -ProtocolExtension::parse_handshake(const Object& message) { - if (message.has_key_map("m")) { - const Object& idMap = message.get_key("m"); +bool +ProtocolExtension::parse_handshake() { + ExtHandshakeMessage message; + staticMap_read_bencode(m_read, m_readPos, message); - for (int t = HANDSHAKE + 1; t < FIRST_INVALID; t++) { - if (!idMap.has_key_value(message_keys[t])) - continue; + for (int t = HANDSHAKE + 1; t < FIRST_INVALID; t++) { + if (!message[message_keys[t]].is_value()) + continue; - uint8_t id = idMap.get_key_value(message_keys[t]); + uint8_t id = message[message_keys[t]].as_value(); - set_remote_supported(t); + set_remote_supported(t); - if (id != m_idMap[t - 1]) { - peer_toggle_remote(t, id != 0); - - m_idMap[t - 1] = id; - } + if (id != m_idMap[t - 1]) { + peer_toggle_remote(t, id != 0); + m_idMap[t - 1] = id; } } @@ -271,31 +354,39 @@ ProtocolExtension::parse_handshake(const Object& message) { unset_local_enabled(t); } - if (message.has_key_value("p")) { - uint16_t port = message.get_key_value("p"); + if (message[key_p].is_value()) { + uint16_t port = message[key_p].as_value(); if (port > 0) m_peerInfo->set_listen_port(port); } - if (message.has_key_value("reqq")) - m_maxQueueLength = message.get_key_value("reqq"); + if (message[key_reqq].is_value()) + m_maxQueueLength = message[key_reqq].as_value(); + + if (message[key_metadataSize].is_value()) + m_download->set_metadata_size(message[key_metadataSize].as_value()); m_flags &= ~flag_initial_handshake; + + return false; } -void -ProtocolExtension::parse_ut_pex(const Object& message) { +bool +ProtocolExtension::parse_ut_pex() { // Ignore message if we're still in the handshake (no connection // yet), or no peers are present. + ExtPEXMessage message; + staticMap_read_bencode(m_read, m_readPos, message); + // TODO: Check if pex is enabled? - if (!message.has_key_string("added")) - return; + if (!message[key_pex_added].is_sstring()) + return false; - const std::string& peers = message.get_key_string("added"); + SimpleString peers = message[key_pex_added].as_sstring(); if (peers.empty()) - return; + return false; AddressList l; l.parse_address_compact(peers); @@ -303,6 +394,82 @@ ProtocolExtension::parse_ut_pex(const Object& message) { l.erase(std::unique(l.begin(), l.end()), l.end()); m_download->peer_list()->insert_available(&l); + + return false; +} + +bool +ProtocolExtension::parse_ut_metadata() { + ExtMetadataMessage message; + + // Piece data comes after bencoded extension message. + const char* dataStart = staticMap_read_bencode(m_read, m_readPos, message); + + switch(message[key_msgType].as_value()) { + case 0: + // Can't process new request while still having data to send. + if (has_pending_message()) + return true; + + send_metadata_piece(message[key_piece].as_value()); + break; + + case 1: + if (m_connection == NULL) + break; + + m_connection->receive_metadata_piece(message[key_piece].as_value(), dataStart, m_readPos - dataStart); + break; + + case 2: + if (m_connection != NULL) + m_connection->receive_metadata_piece(message[key_piece].as_value(), NULL, 0); + break; + }; + + return false; +} + +void +ProtocolExtension::send_metadata_piece(size_t piece) { + // Reject out-of-range piece, or if we don't have the complete metadata yet. + size_t metadataSize = m_download->info()->metadata_size(); + size_t pieceEnd = (metadataSize + metadata_piece_size - 1) >> metadata_piece_shift; + + if (m_download->info()->is_meta_download() || piece >= pieceEnd) { + // reject: { "msg_type" => 2, "piece" => ... } + m_pendingType = UT_METADATA; + m_pending = build_bencode(40, "d8:msg_typei2e5:piecei%zuee", piece); + return; + } + + // These messages will be rare, so we'll just build the + // metadata here instead of caching it uselessly. + char* buffer = new char[metadataSize]; + object_buffer_t result = object_write_bencode_c(object_write_to_buffer, NULL, object_buffer_t(buffer, buffer + metadataSize), + &(*manager->download_manager()->find(m_download->info()))->bencode()->get_key("info")); + + // data: { "msg_type" => 1, "piece" => ..., "total_size" => ... } followed by piece data (outside of dictionary) + size_t length = piece == pieceEnd - 1 ? m_download->info()->metadata_size() % metadata_piece_size : metadata_piece_size; + m_pendingType = UT_METADATA; + m_pending = build_bencode(length + 128, "d8:msg_typei1e5:piecei%zue10:total_sizei%zuee", piece, metadataSize); + + memcpy(m_pending.end(), buffer + (piece << metadata_piece_shift), length); + m_pending.set(m_pending.data(), m_pending.end() + length, m_pending.owned()); + delete [] buffer; +} + +bool +ProtocolExtension::request_metadata_piece(const Piece* p) { + if (p->offset() % metadata_piece_size) + throw internal_error("ProtocolExtension::request_metadata_piece got misaligned piece offset."); + + if (has_pending_message()) + return false; + + m_pendingType = UT_METADATA; + m_pending = build_bencode(40, "d8:msg_typei0e5:piecei%uee", (unsigned)(p->offset() >> metadata_piece_shift)); + return true; } } diff --git a/src/protocol/extensions.h b/src/protocol/extensions.h index 1c370fc..485e7d7 100644 --- a/src/protocol/extensions.h +++ b/src/protocol/extensions.h @@ -46,6 +46,13 @@ #include "download/download_info.h" #include "net/data_buffer.h" +// Not really important, so no need to make this a configure check. +#ifdef __GNUC__ +#define ATTRIBUTE_PRINTF(num) __attribute__ ((format (printf, num, num+1))) +#else +#define ATTRIBUTE_PRINTF(num) +#endif + namespace torrent { class ProtocolExtension { @@ -53,6 +60,7 @@ public: typedef enum { HANDSHAKE = 0, UT_PEX, + UT_METADATA, FIRST_INVALID, // first invalid message ID @@ -71,11 +79,13 @@ public: static const int flag_local_enabled_base = 1<<8; static const int flag_remote_supported_base = 1<<16; - static const char* message_keys[FIRST_INVALID]; - // Number of extensions we support, not counting handshake. static const int extension_count = FIRST_INVALID - HANDSHAKE - 1; + // Fixed size of a metadata piece (16 KB). + static const size_t metadata_piece_shift = 14; + static const size_t metadata_piece_size = 1 << metadata_piece_shift; + ProtocolExtension(); ~ProtocolExtension() { delete [] m_read; } @@ -86,6 +96,7 @@ public: static ProtocolExtension make_default(); void set_info(PeerInfo* peerInfo, DownloadMain* download) { m_peerInfo = peerInfo; m_download = download; } + void set_connection(PeerConnectionBase* c) { m_connection = c; } DataBuffer generate_handshake_message(); static DataBuffer generate_toggle_message(MessageType t, bool on); @@ -107,7 +118,7 @@ public: // Handle reading extension data from peer. void read_start(int type, uint32_t length, bool skip); - void read_done(); + bool read_done(); char* read_position() { return m_readPos; } bool read_move(uint32_t v) { m_readPos += v; return (m_readLeft -= v) == 0; } @@ -127,11 +138,23 @@ public: void clear_initial_pex() { m_flags &= ~flag_initial_pex; } void reset() { std::memset(&m_idMap, 0, sizeof(m_idMap)); } + bool request_metadata_piece(const Piece* p); + + // To handle cases where the extension protocol needs to send a reply. + bool has_pending_message() const { return m_pendingType != HANDSHAKE; } + MessageType pending_message_type() const { return m_pendingType; } + DataBuffer pending_message_data() { return m_pending.release(); } + void clear_pending_message() { if (m_pending.empty()) m_pendingType = HANDSHAKE; } + private: - void parse_handshake(const Object& message); - void parse_ut_pex(const Object& message); + bool parse_handshake(); + bool parse_ut_pex(); + bool parse_ut_metadata(); + + static DataBuffer build_bencode(size_t maxLength, const char* format, ...) ATTRIBUTE_PRINTF(2); void peer_toggle_remote(int type, bool active); + void send_metadata_piece(size_t piece); // Map of IDs peer uses for each extension message type, excluding // HANDSHAKE. @@ -142,11 +165,15 @@ private: int m_flags; PeerInfo* m_peerInfo; DownloadMain* m_download; + PeerConnectionBase* m_connection; uint8_t m_readType; uint32_t m_readLeft; char* m_read; char* m_readPos; + + MessageType m_pendingType; + DataBuffer m_pending; }; inline @@ -156,10 +183,13 @@ ProtocolExtension::ProtocolExtension() : m_flags(flag_local_enabled_base | flag_remote_supported_base | flag_initial_handshake), m_peerInfo(NULL), m_download(NULL), + m_connection(NULL), m_readType(FIRST_INVALID), - m_read(NULL) { + m_read(NULL), + m_pendingType(HANDSHAKE) { reset(); + set_local_enabled(UT_METADATA); } inline ProtocolExtension diff --git a/src/protocol/handshake.cc b/src/protocol/handshake.cc index d863f7b..7fb389b 100644 --- a/src/protocol/handshake.cc +++ b/src/protocol/handshake.cc @@ -723,6 +723,17 @@ restart: case READ_MESSAGE: case POST_HANDSHAKE: + // For meta-downloads, we aren't interested in the bitfield or + // extension messages here, PCMetadata handles all that. The + // bitfield only refers to the single-chunk meta-data, so fake that. + if (m_download->info()->is_meta_download()) { + m_bitfield.set_size_bits(1); + m_bitfield.allocate(); + m_bitfield.set(0); + read_done(); + break; + } + fill_read_buffer(5); // Received a keep-alive message which means we won't be @@ -1022,6 +1033,10 @@ Handshake::prepare_peer_info() { std::memcpy(m_peerInfo->set_options(), m_options, 8); m_peerInfo->mutable_id().assign((const char*)m_readBuffer.position()); m_readBuffer.consume(20); + + // For meta downloads, we require support of the extension protocol. + if (m_download->info()->is_meta_download() && !m_peerInfo->supports_extensions()) + throw handshake_error(ConnectionManager::handshake_dropped, e_handshake_unwanted_connection); } void diff --git a/src/protocol/peer_connection_base.cc b/src/protocol/peer_connection_base.cc index ab043a6..815ea93 100644 --- a/src/protocol/peer_connection_base.cc +++ b/src/protocol/peer_connection_base.cc @@ -93,8 +93,7 @@ PeerConnectionBase::~PeerConnectionBase() { if (m_extensions != NULL && !m_extensions->is_default()) delete m_extensions; - if (m_extensionMessage.owned()) - m_extensionMessage.clear(); + m_extensionMessage.clear(); } void @@ -116,6 +115,8 @@ PeerConnectionBase::initialize(DownloadMain* download, PeerInfo* peerInfo, Socke m_encryption = *encryptionInfo; m_extensions = extensions; + m_extensions->set_connection(this); + m_peerChunks.set_peer_info(m_peerInfo); m_peerChunks.bitfield()->swap(*bitfield); @@ -581,8 +582,12 @@ PeerConnectionBase::down_extension() { m_extensions->read_move(bytes); } - if (m_extensions->is_complete()) - m_extensions->read_done(); + // If extension can't be processed yet (due to a pending write), + // disable reads until the pending message is completely sent. + if (m_extensions->is_complete() && !m_extensions->is_invalid() && !m_extensions->read_done()) { + manager->poll()->remove_read(this); + return false; + } return m_extensions->is_complete(); } @@ -693,12 +698,15 @@ PeerConnectionBase::up_extension() { if (m_extensionOffset < m_extensionMessage.length()) return false; - // clear() deletes the buffer, only do that if we made a copy, - // otherwise the buffer is shared among all connections. - if (m_extensionMessage.owned()) - m_extensionMessage.clear(); - else - m_extensionMessage.set(NULL, NULL, false); + m_extensionMessage.clear(); + + // If we have an unprocessed message, process it now and enable reads again. + if (m_extensions->is_complete() && !m_extensions->is_invalid()) { + if (!m_extensions->read_done()) + throw internal_error("PeerConnectionBase::up_extension could not process complete extension message."); + + manager->poll()->insert_read(this); + } return true; } @@ -857,4 +865,16 @@ PeerConnectionBase::send_pex_message() { return true; } +// Extension protocol needs to send a reply. +bool +PeerConnectionBase::send_ext_message() { + write_prepare_extension(m_extensions->pending_message_type(), m_extensions->pending_message_data()); + m_extensions->clear_pending_message(); + return true; +} + +void +PeerConnectionBase::receive_metadata_piece(uint32_t piece, const char* data, uint32_t length) { +} + } diff --git a/src/protocol/peer_connection_base.h b/src/protocol/peer_connection_base.h index 2994963..d131341 100644 --- a/src/protocol/peer_connection_base.h +++ b/src/protocol/peer_connection_base.h @@ -140,6 +140,9 @@ public: void read_insert_poll_safe(); void write_insert_poll_safe(); + // Communication with the protocol extensions + virtual void receive_metadata_piece(uint32_t piece, const char* data, uint32_t length); + protected: static const uint32_t extension_must_encrypt = ~uint32_t(); @@ -179,6 +182,7 @@ protected: bool try_request_pieces(); bool send_pex_message(); + bool send_ext_message(); DownloadMain* m_download; diff --git a/src/protocol/peer_connection_leech.cc b/src/protocol/peer_connection_leech.cc index a75d333..36c6d7a 100644 --- a/src/protocol/peer_connection_leech.cc +++ b/src/protocol/peer_connection_leech.cc @@ -333,9 +333,13 @@ PeerConnection::read_message() { m_down->set_state(ProtocolRead::READ_EXTENSION); } - if (down_extension()) - m_down->set_state(ProtocolRead::IDLE); + if (!down_extension()) + return false; + if (m_extensions->has_pending_message()) + write_insert_poll_safe(); + + m_down->set_state(ProtocolRead::IDLE); return true; default: @@ -433,6 +437,9 @@ PeerConnection::event_read() { if (!down_extension()) return; + if (m_extensions->has_pending_message()) + write_insert_poll_safe(); + m_down->set_state(ProtocolRead::IDLE); break; @@ -546,6 +553,10 @@ PeerConnection::fill_write_buffer() { send_pex_message()) { // Don't do anything else if send_pex_message() succeeded. + } else if (m_extensions->has_pending_message() && m_up->can_write_extension() && + send_ext_message()) { + // Same. + } else if (!m_upChoke.choked() && !m_peerChunks.upload_queue()->empty() && m_up->can_write_piece() && diff --git a/src/protocol/peer_connection_metadata.cc b/src/protocol/peer_connection_metadata.cc new file mode 100644 index 0000000..24f13ca --- /dev/null +++ b/src/protocol/peer_connection_metadata.cc @@ -0,0 +1,461 @@ +// libTorrent - BitTorrent library +// Copyright (C) 2005-2007, Jari Sundell +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation; either version 2 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program; if not, write to the Free Software +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +// +// In addition, as a special exception, the copyright holders give +// permission to link the code of portions of this program with the +// OpenSSL library under certain conditions as described in each +// individual source file, and distribute linked combinations +// including the two. +// +// You must obey the GNU General Public License in all respects for +// all of the code used other than OpenSSL. If you modify file(s) +// with this exception, you may extend this exception to your version +// of the file(s), but you are not obligated to do so. If you do not +// wish to do so, delete this exception statement from your version. +// If you delete this exception statement from all source files in the +// program, then also delete it here. +// +// Contact: Jari Sundell +// +// Skomakerveien 33 +// 3185 Skoppum, NORWAY + +#include "config.h" + +#include +#include + +#include "data/chunk_list_node.h" +#include "download/choke_manager.h" +#include "download/chunk_selector.h" +#include "download/chunk_statistics.h" +#include "download/download_info.h" +#include "download/download_main.h" +#include "torrent/dht_manager.h" +#include "torrent/peer/connection_list.h" +#include "torrent/peer/peer_info.h" + +#include "extensions.h" +#include "peer_connection_metadata.h" + +namespace torrent { + +PeerConnectionMetadata::~PeerConnectionMetadata() { +} + +void +PeerConnectionMetadata::initialize_custom() { +} + +void +PeerConnectionMetadata::update_interested() { +} + +bool +PeerConnectionMetadata::receive_keepalive() { + if (cachedTime - m_timeLastRead > rak::timer::from_seconds(240)) + return false; + + m_tryRequest = true; + + // There's no point in adding ourselves to the write poll if the + // buffer is full, as that will already have been taken care of. + if (m_up->get_state() == ProtocolWrite::IDLE && + m_up->can_write_keepalive()) { + + write_insert_poll_safe(); + + ProtocolBuffer<512>::iterator old_end = m_up->buffer()->end(); + m_up->write_keepalive(); + + if (is_encrypted()) + m_encryption.encrypt(old_end, m_up->buffer()->end() - old_end); + } + + return true; +} + +// We keep the message in the buffer if it is incomplete instead of +// keeping the state and remembering the read information. This +// shouldn't happen very often compared to full reads. +inline bool +PeerConnectionMetadata::read_message() { + ProtocolBuffer<512>* buf = m_down->buffer(); + + if (buf->remaining() < 4) + return false; + + // Remember the start of the message so we may reset it if we don't + // have the whole message. + ProtocolBuffer<512>::iterator beginning = buf->position(); + + uint32_t length = buf->read_32(); + + if (length == 0) { + // Keepalive message. + m_down->set_last_command(ProtocolBase::KEEP_ALIVE); + + return true; + + } else if (buf->remaining() < 1) { + buf->set_position_itr(beginning); + return false; + + } else if (length > (1 << 20)) { + throw communication_error("PeerConnection::read_message() got an invalid message length."); + } + + m_down->set_last_command((ProtocolBase::Protocol)buf->peek_8()); + + // Ignore most messages, they aren't relevant for a metadata download. + switch (buf->read_8()) { + case ProtocolBase::CHOKE: + case ProtocolBase::UNCHOKE: + case ProtocolBase::INTERESTED: + case ProtocolBase::NOT_INTERESTED: + return true; + + case ProtocolBase::HAVE: + if (!m_down->can_read_have_body()) + break; + + buf->read_32(); + return true; + + case ProtocolBase::REQUEST: + if (!m_down->can_read_request_body()) + break; + + m_down->read_request(); + return true; + + case ProtocolBase::PIECE: + throw communication_error("Received a piece but the connection is strictly for meta data."); + + case ProtocolBase::CANCEL: + if (!m_down->can_read_cancel_body()) + break; + + m_down->read_request(); + return true; + + case ProtocolBase::PORT: + if (!m_down->can_read_port_body()) + break; + + manager->dht_manager()->add_node(m_peerInfo->socket_address(), m_down->buffer()->read_16()); + return true; + + case ProtocolBase::EXTENSION_PROTOCOL: + if (!m_down->can_read_extension_body()) + break; + + if (m_extensions->is_default()) { + m_extensions = new ProtocolExtension(); + m_extensions->set_info(m_peerInfo, m_download); + } + + { + int extension = m_down->buffer()->read_8(); + m_extensions->read_start(extension, length - 2, (extension == ProtocolExtension::UT_PEX) && !m_download->want_pex_msg()); + m_down->set_state(ProtocolRead::READ_EXTENSION); + } + + if (!down_extension()) + return false; + + // Drop peer if it disabled the metadata extension. + if (!m_extensions->is_remote_supported(ProtocolExtension::UT_METADATA)) + throw close_connection(); + + m_down->set_state(ProtocolRead::IDLE); + m_tryRequest = true; + write_insert_poll_safe(); + + return true; + + case ProtocolBase::BITFIELD: + // Discard the bitfield sent by the peer. + m_skipLength = length - 1; + m_down->set_state(ProtocolRead::READ_SKIP_PIECE); + return false; + + default: + throw communication_error("Received unsupported message type."); + } + + // We were unsuccessfull in reading the message, need more data. + buf->set_position_itr(beginning); + return false; +} + +void +PeerConnectionMetadata::event_read() { + m_timeLastRead = cachedTime; + + // Need to make sure ProtocolBuffer::end() is pointing to the end of + // the unread data, and that the unread data starts from the + // beginning of the buffer. Or do we use position? Propably best, + // therefor ProtocolBuffer::position() points to the beginning of + // the unused data. + + try { + + // Normal read. + // + // We rarely will read zero bytes as the read of 64 bytes will + // almost always either not fill up or it will require additional + // reads. + // + // Only loop when end hits 64. + + do { + switch (m_down->get_state()) { + case ProtocolRead::IDLE: + if (m_down->buffer()->size_end() < read_size) { + unsigned int length = read_stream_throws(m_down->buffer()->end(), read_size - m_down->buffer()->size_end()); + m_down->throttle()->node_used_unthrottled(length); + + if (is_encrypted()) + m_encryption.decrypt(m_down->buffer()->end(), length); + + m_down->buffer()->move_end(length); + } + + while (read_message()); + + if (m_down->buffer()->size_end() == read_size) { + m_down->buffer()->move_unused(); + break; + } else { + m_down->buffer()->move_unused(); + return; + } + + case ProtocolRead::READ_EXTENSION: + if (!down_extension()) + return; + + // Drop peer if it disabled the metadata extension. + if (!m_extensions->is_remote_supported(ProtocolExtension::UT_METADATA)) + throw close_connection(); + + m_down->set_state(ProtocolRead::IDLE); + m_tryRequest = true; + write_insert_poll_safe(); + break; + + // Actually skipping the bitfield. + // We never receive normal piece messages anyway. + case ProtocolRead::READ_SKIP_PIECE: + if (!read_skip_bitfield()) + return; + + m_down->set_state(ProtocolRead::IDLE); + break; + + default: + throw internal_error("PeerConnection::event_read() wrong state."); + } + + // Figure out how to get rid of the shouldLoop boolean. + } while (true); + + // Exception handlers: + + } catch (close_connection& e) { + m_download->connection_list()->erase(this, 0); + + } catch (blocked_connection& e) { + m_download->info()->signal_network_log().emit("Momentarily blocked read connection."); + m_download->connection_list()->erase(this, 0); + + } catch (network_error& e) { + m_download->connection_list()->erase(this, 0); + + } catch (storage_error& e) { + m_download->info()->signal_storage_error().emit(e.what()); + m_download->connection_list()->erase(this, 0); + + } catch (base_error& e) { + std::stringstream s; + s << "Connection read fd(" << get_fd().get_fd() << ',' << m_down->get_state() << ',' << m_down->last_command() << ") \"" << e.what() << '"'; + + throw internal_error(s.str()); + } +} + +inline void +PeerConnectionMetadata::fill_write_buffer() { + ProtocolBuffer<512>::iterator old_end = m_up->buffer()->end(); + + if (m_tryRequest) + m_tryRequest = try_request_metadata_pieces(); + + if (m_sendPEXMask && m_up->can_write_extension() && + send_pex_message()) { + // Don't do anything else if send_pex_message() succeeded. + + } else if (m_extensions->has_pending_message() && m_up->can_write_extension() && + send_ext_message()) { + // Same. + } + + if (is_encrypted()) + m_encryption.encrypt(old_end, m_up->buffer()->end() - old_end); +} + +void +PeerConnectionMetadata::event_write() { + try { + + do { + + switch (m_up->get_state()) { + case ProtocolWrite::IDLE: + + fill_write_buffer(); + + if (m_up->buffer()->remaining() == 0) { + manager->poll()->remove_write(this); + return; + } + + m_up->set_state(ProtocolWrite::MSG); + + case ProtocolWrite::MSG: + if (!m_up->buffer()->consume(m_up->throttle()->node_used_unthrottled(write_stream_throws(m_up->buffer()->position(), m_up->buffer()->remaining())))) + return; + + m_up->buffer()->reset(); + + if (m_up->last_command() != ProtocolBase::EXTENSION_PROTOCOL) { + m_up->set_state(ProtocolWrite::IDLE); + break; + } + + m_up->set_state(ProtocolWrite::WRITE_EXTENSION); + + case ProtocolWrite::WRITE_EXTENSION: + if (!up_extension()) + return; + + m_up->set_state(ProtocolWrite::IDLE); + break; + + default: + throw internal_error("PeerConnection::event_write() wrong state."); + } + + } while (true); + + } catch (close_connection& e) { + m_download->connection_list()->erase(this, 0); + + } catch (blocked_connection& e) { + m_download->info()->signal_network_log().emit("Momentarily blocked write connection."); + m_download->connection_list()->erase(this, 0); + + } catch (network_error& e) { + m_download->connection_list()->erase(this, 0); + + } catch (storage_error& e) { + m_download->info()->signal_storage_error().emit(e.what()); + m_download->connection_list()->erase(this, 0); + + } catch (base_error& e) { + std::stringstream s; + s << "Connection write fd(" << get_fd().get_fd() << ',' << m_up->get_state() << ',' << m_up->last_command() << ") \"" << e.what() << '"'; + + throw internal_error(s.str()); + } +} + +bool +PeerConnectionMetadata::read_skip_bitfield() { + if (m_down->buffer()->remaining()) { + uint32_t length = std::min(m_skipLength, (uint32_t)m_down->buffer()->remaining()); + m_down->buffer()->consume(length); + m_skipLength -= length; + } + + if (m_skipLength) { + uint32_t length = std::min(m_skipLength, (uint32_t)null_buffer_size); + length = read_stream_throws(m_nullBuffer, length); + if (!length) + return false; + m_skipLength -= length; + } + + return !m_skipLength; +} + +// Same as the PCB code, but only one at a time and with the extension protocol. +bool +PeerConnectionMetadata::try_request_metadata_pieces() { + if (m_download->file_list()->chunk_size() == 1 || !m_extensions->is_remote_supported(ProtocolExtension::UT_METADATA)) + return false; + + if (download_queue()->queued_empty()) + m_downStall = 0; + + uint32_t pipeSize = download_queue()->calculate_pipe_size(m_peerChunks.download_throttle()->rate()->rate()); + + // Don't start requesting if we can't do it in large enough chunks. + if (download_queue()->queued_size() >= (pipeSize + 10) / 2) + return false; + + if (!download_queue()->queued_size() < pipeSize || !m_up->can_write_extension() || + m_extensions->has_pending_message()) + return false; + + const Piece* p = download_queue()->delegate(); + + if (p == NULL) + return false; + + if (!m_download->file_list()->is_valid_piece(*p) || !m_peerChunks.bitfield()->get(p->index())) + throw internal_error("PeerConnectionMetadata::try_request_metadata_pieces() tried to use an invalid piece."); + + return m_extensions->request_metadata_piece(p); +} + +void +PeerConnectionMetadata::receive_metadata_piece(uint32_t piece, const char* data, uint32_t length) { + if (data == NULL) { + // Length is not set in a reject message. + length = ProtocolExtension::metadata_piece_size; + if ((piece << ProtocolExtension::metadata_piece_shift) + ProtocolExtension::metadata_piece_size >= m_download->file_list()->size_bytes()) + length = m_download->file_list()->chunk_size() % ProtocolExtension::metadata_piece_size; + m_tryRequest = false; + read_cancel_piece(Piece(0, piece << ProtocolExtension::metadata_piece_shift, length)); + return; + } + + if (!down_chunk_start(Piece(0, piece << ProtocolExtension::metadata_piece_shift, length))) + down_chunk_skip_process(data, length); + else + down_chunk_process(data, length); + + if (!m_downloadQueue.transfer()->is_finished()) + throw internal_error("PeerConnectionMetadata::receive_metadata_piece did not have complete piece."); + + m_tryRequest = true; + down_chunk_finished(); +} + +} diff --git a/src/protocol/peer_connection_metadata.h b/src/protocol/peer_connection_metadata.h new file mode 100644 index 0000000..127700a --- /dev/null +++ b/src/protocol/peer_connection_metadata.h @@ -0,0 +1,73 @@ +// libTorrent - BitTorrent library +// Copyright (C) 2005-2007, Jari Sundell +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation; either version 2 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program; if not, write to the Free Software +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +// +// In addition, as a special exception, the copyright holders give +// permission to link the code of portions of this program with the +// OpenSSL library under certain conditions as described in each +// individual source file, and distribute linked combinations +// including the two. +// +// You must obey the GNU General Public License in all respects for +// all of the code used other than OpenSSL. If you modify file(s) +// with this exception, you may extend this exception to your version +// of the file(s), but you are not obligated to do so. If you do not +// wish to do so, delete this exception statement from your version. +// If you delete this exception statement from all source files in the +// program, then also delete it here. +// +// Contact: Jari Sundell +// +// Skomakerveien 33 +// 3185 Skoppum, NORWAY + +#ifndef LIBTORRENT_PROTOCOL_PEER_CONNECTION_METADATA_H +#define LIBTORRENT_PROTOCOL_PEER_CONNECTION_METADATA_H + +#include "peer_connection_base.h" + +#include "torrent/download.h" + +namespace torrent { + +class PeerConnectionMetadata : public PeerConnectionBase { +public: + ~PeerConnectionMetadata(); + + virtual void initialize_custom(); + virtual void update_interested(); + virtual bool receive_keepalive(); + + virtual void event_read(); + virtual void event_write(); + + virtual void receive_metadata_piece(uint32_t piece, const char* data, uint32_t length); + +private: + inline bool read_message(); + + bool read_skip_bitfield(); + + bool try_request_metadata_pieces(); + + inline void fill_write_buffer(); + + uint32_t m_skipLength; +}; + +} + +#endif diff --git a/src/protocol/peer_factory.cc b/src/protocol/peer_factory.cc index 7ab9fe8..cfe6a1e 100644 --- a/src/protocol/peer_factory.cc +++ b/src/protocol/peer_factory.cc @@ -38,6 +38,7 @@ #include "peer_factory.h" #include "peer_connection_leech.h" +#include "peer_connection_metadata.h" namespace torrent { @@ -62,4 +63,11 @@ createPeerConnectionInitialSeed(bool encrypted) { return pc; } +PeerConnectionBase* +createPeerConnectionMetadata(bool encrypted) { + PeerConnectionBase* pc = new PeerConnectionMetadata; + + return pc; +} + } diff --git a/src/protocol/peer_factory.h b/src/protocol/peer_factory.h index 363a5c3..f22d76f 100644 --- a/src/protocol/peer_factory.h +++ b/src/protocol/peer_factory.h @@ -44,6 +44,7 @@ class PeerConnectionBase; PeerConnectionBase* createPeerConnectionDefault(bool encrypted); PeerConnectionBase* createPeerConnectionSeed(bool encrypted); PeerConnectionBase* createPeerConnectionInitialSeed(bool encrypted); +PeerConnectionBase* createPeerConnectionMetadata(bool encrypted); } diff --git a/src/torrent/Makefile.am b/src/torrent/Makefile.am index bec124d..820ce52 100644 --- a/src/torrent/Makefile.am +++ b/src/torrent/Makefile.am @@ -41,6 +41,9 @@ libsub_torrent_la_SOURCES = \ rate.h \ resume.cc \ resume.h \ + simple_string.h \ + static_map.cc \ + static_map.h \ throttle.cc \ throttle.h \ torrent.cc \ @@ -74,6 +77,8 @@ libtorrentinclude_HEADERS = \ poll_select.h \ rate.h \ resume.h \ + simple_string.h \ + static_map.h \ throttle.h \ torrent.h \ tracker.h \ diff --git a/src/torrent/data/file_list.cc b/src/torrent/data/file_list.cc index 2f5d8d2..7208612 100644 --- a/src/torrent/data/file_list.cc +++ b/src/torrent/data/file_list.cc @@ -466,6 +466,18 @@ FileList::open(int flags) { m_isOpen = true; m_frozenRootDir = m_rootDir; + + // For meta-downloads, if the file exists, we have to assume that + // it is either 0 or 1 length or the correct size. If the size + // turns out wrong later, a communication_error will be thrown elsewhere + // to alert the user in this (unlikely) case. + if (size_bytes() < 2) { + rak::file_stat stat; + + // This probably recurses into open() once, but that is harmless. + if (stat.update((*begin())->frozen_path()) && stat.size() > 1) + return reset_filesize(stat.size()); + } } void @@ -661,4 +673,14 @@ FileList::update_completed() { } } +void +FileList::reset_filesize(int64_t size) { + close(); + m_chunkSize = size; + m_torrentSize = size; + (*begin())->set_size_bytes(size); + (*begin())->set_range(m_chunkSize); + open(open_no_create); +} + } diff --git a/src/torrent/data/file_list.h b/src/torrent/data/file_list.h index bcc8939..60d418a 100644 --- a/src/torrent/data/file_list.h +++ b/src/torrent/data/file_list.h @@ -167,6 +167,10 @@ protected: iterator inc_completed(iterator firstItr, uint32_t index) LIBTORRENT_NO_EXPORT; void update_completed() LIBTORRENT_NO_EXPORT; + // Used for meta downloads; we only know the + // size after the first extension handshake. + void reset_filesize(int64_t) LIBTORRENT_NO_EXPORT; + private: bool open_file(File* node, const Path& lastPath, int flags) LIBTORRENT_NO_EXPORT; void make_directory(Path::const_iterator pathBegin, Path::const_iterator pathEnd, Path::const_iterator startItr) LIBTORRENT_NO_EXPORT; diff --git a/src/torrent/download.cc b/src/torrent/download.cc index d6cc199..49daad9 100644 --- a/src/torrent/download.cc +++ b/src/torrent/download.cc @@ -225,6 +225,11 @@ Download::set_pex_enabled(bool enabled) { m_ptr->info()->set_pex_enabled(enabled); } +bool +Download::is_meta_download() const { + return m_ptr->info()->is_meta_download(); +} + const std::string& Download::name() const { if (m_ptr == NULL) @@ -504,6 +509,11 @@ Download::connection_type() const { void Download::set_connection_type(ConnectionType t) { + if (m_ptr->info()->is_meta_download()) { + m_ptr->main()->connection_list()->slot_new_connection(&createPeerConnectionMetadata); + return; + } + switch (t) { case CONNECTION_LEECH: m_ptr->main()->connection_list()->slot_new_connection(&createPeerConnectionDefault); diff --git a/src/torrent/download.h b/src/torrent/download.h index 5e9700e..5d16d4e 100644 --- a/src/torrent/download.h +++ b/src/torrent/download.h @@ -100,6 +100,8 @@ public: bool is_pex_enabled() const; void set_pex_enabled(bool enabled); + bool is_meta_download() const; + // Returns "" if the object is not valid. const std::string& name() const; @@ -184,6 +186,7 @@ public: CONNECTION_LEECH, CONNECTION_SEED, CONNECTION_INITIAL_SEED, + CONNECTION_METADATA, } ConnectionType; ConnectionType connection_type() const; diff --git a/src/torrent/hash_string.h b/src/torrent/hash_string.h index f62d450..14623f7 100644 --- a/src/torrent/hash_string.h +++ b/src/torrent/hash_string.h @@ -44,6 +44,7 @@ #include #include #include +#include namespace torrent { @@ -85,6 +86,8 @@ public: std::string str() const { return std::string(m_data, size_data); } + SimpleString s_str() const { return SimpleString(m_data, size_data); } + void clear(int v = 0) { std::memset(data(), v, size()); } void assign(const value_type* src) { std::memcpy(data(), src, size()); } @@ -96,6 +99,7 @@ public: // size_data. static const HashString* cast_from(const char* src) { return (const HashString*)src; } static const HashString* cast_from(const std::string& src) { return (const HashString*)src.c_str(); } + static const HashString* cast_from(const SimpleString& src){ return (const HashString*)src.c_str(); } static HashString* cast_from(char* src) { return (HashString*)src; } diff --git a/src/torrent/object.cc b/src/torrent/object.cc index 2b1cf41..b609f9c 100644 --- a/src/torrent/object.cc +++ b/src/torrent/object.cc @@ -44,47 +44,59 @@ namespace torrent { -Object& -Object::get_key(const std::string& k) { - check_throw(TYPE_MAP); - map_type::iterator itr = m_map->find(k); +std::pair +Object::map_type::insert(const value_type& value) { + base_type::iterator itr = lower_bound(value.first); - if (itr == m_map->end()) - throw bencode_error("Object operator [" + k + "] could not find element"); + if (itr != end() && !key_comp()(value.first, itr->first)) + return std::make_pair(itr, false); - return itr->second; -} + // Insert with an allocated copy of the key. + itr = base_type::insert(itr, value_type(value.first.copy(), value.second)); + // This means the value was actually already present. + if (itr->second.get_string() != NULL) + throw internal_error("Object::map_type::insert failed to insert value."); -const Object& -Object::get_key(const std::string& k) const { - check_throw(TYPE_MAP); - map_type::const_iterator itr = m_map->find(k); + // Make entry own the string and free it when erased. + itr->second.set_string(itr->first.c_str()); - if (itr == m_map->end()) - throw bencode_error("Object operator [" + k + "] could not find element"); + return std::make_pair(itr, true); +} - return itr->second; +Object::map_type::base_type::iterator +Object::map_type::insert(base_type::iterator itr, const value_type& value) { + SimpleString copy = value.first.copy(); + itr = base_type::insert(itr, value_type(copy, value.second)); + + // If the entry already owns its string, it wasn't really + // inserted and already existed, so discard the copy. + if (itr->second.get_string() != NULL) + delete [] copy.c_str(); + else + itr->second.set_string(itr->first.c_str()); + + return itr; } Object& -Object::get_key(const char* k) { +Object::get_key(const key_type& k) { check_throw(TYPE_MAP); - map_type::iterator itr = m_map->find(std::string(k)); + map_type::iterator itr = m_map->find(k); if (itr == m_map->end()) - throw bencode_error("Object operator [" + std::string(k) + "] could not find element"); + throw bencode_error("Object operator [" + k.str() + "] could not find element"); return itr->second; } const Object& -Object::get_key(const char* k) const { +Object::get_key(const key_type& k) const { check_throw(TYPE_MAP); - map_type::iterator itr = m_map->find(std::string(k)); + map_type::iterator itr = m_map->find(k); if (itr == m_map->end()) - throw bencode_error("Object operator [" + std::string(k) + "] could not find element"); + throw bencode_error("Object operator [" + k.str() + "] could not find element"); return itr->second; } @@ -143,7 +155,7 @@ Object::merge_copy(const Object& object, uint32_t maxDepth) { while (srcItr != srcLast) { destItr = std::find_if(destItr, dest.end(), rak::less_equal(srcItr->first, rak::mem_ref(&map_type::value_type::first))); - if (srcItr->first < destItr->first) + if (dest.key_comp()(srcItr->first, destItr->first)) // destItr remains valid and pointing to the next possible // position. dest.insert(destItr, *srcItr); @@ -195,6 +207,7 @@ Object::operator = (const Object& src) { case TYPE_STRING: m_string = new string_type(*src.m_string); break; case TYPE_LIST: m_list = new list_type(*src.m_list); break; case TYPE_MAP: m_map = new map_type(*src.m_map); break; + case TYPE_SSTRING:m_sstring = src.m_sstring; break; } return *this; diff --git a/src/torrent/object.h b/src/torrent/object.h index 7ad040b..b7b4e8f 100644 --- a/src/torrent/object.h +++ b/src/torrent/object.h @@ -42,21 +42,56 @@ #include #include #include +#include namespace torrent { -// TODO: Look into making a custom comp and allocator classes for the -// map_type which use a const char* for key_type. -// // TODO: Use placement new/delete in order to avoid the extra level of // indirection caused by the union. class LIBTORRENT_EXPORT Object { + template + class string_wrapper : public T { + public: + string_wrapper() : T(), m_string(NULL) {} + string_wrapper(const T& value) : T(value), m_string(NULL) {} + string_wrapper(const string_wrapper& other) : T(other), m_string(NULL) {} + + ~string_wrapper() { delete [] m_string; m_string = NULL; } + + const char* get_string() const { return m_string; } + void set_string(const char* s) { m_string = s; } + + private: + string_wrapper& operator = (const string_wrapper& other); + + const char* m_string; + }; + public: typedef int64_t value_type; typedef std::string string_type; typedef std::list list_type; - typedef std::map map_type; + class map_type : public std::map > { + public: + typedef std::map > base_type; + using base_type::value_type; + using base_type::key_type; + + map_type(const map_type& other) : base_type(other.key_comp()) { insert(other.begin(), other.end()); } + map_type() {} + + std::pair insert(const value_type& value); + base_type::iterator insert(base_type::iterator itr, const value_type& value); + + template + void insert(InputIterator begin, InputIterator end); + + Object& operator[] (key_type key); + + private: + map_type& operator = (const map_type& other); + }; typedef map_type::key_type key_type; typedef list_type::iterator list_iterator; @@ -82,13 +117,16 @@ public: TYPE_VALUE, TYPE_STRING, TYPE_LIST, - TYPE_MAP + TYPE_MAP, + TYPE_SSTRING, // Only used in StaticMap. }; Object() : m_flags(TYPE_NONE) {} Object(const value_type v) : m_flags(TYPE_VALUE), m_value(v) {} Object(const char* s) : m_flags(TYPE_STRING), m_string(new string_type(s)) {} Object(const string_type& s) : m_flags(TYPE_STRING), m_string(new string_type(s)) {} + Object(const char* s, size_t l) : m_flags(TYPE_SSTRING), m_sstring(SimpleString(s, l)) {} + Object(SimpleString s) : m_flags(TYPE_SSTRING), m_sstring(s) {} Object(const Object& b); ~Object() { clear(); } @@ -96,6 +134,7 @@ public: // Move this out of the class namespace, call them create_object_. static Object create_value() { return Object(value_type()); } static Object create_string() { return Object(string_type()); } + static Object create_sstring(){ return Object(SimpleString()); } static Object create_list() { Object tmp; tmp.m_flags = TYPE_LIST; tmp.m_list = new list_type(); return tmp; } static Object create_map() { Object tmp; tmp.m_flags = TYPE_MAP; tmp.m_map = new map_type(); return tmp; } @@ -120,6 +159,7 @@ public: bool is_string() const { return type() == TYPE_STRING; } bool is_list() const { return type() == TYPE_LIST; } bool is_map() const { return type() == TYPE_MAP; } + bool is_sstring() const { return type() == TYPE_SSTRING; } value_type& as_value() { check_throw(TYPE_VALUE); return m_value; } const value_type& as_value() const { check_throw(TYPE_VALUE); return m_value; } @@ -133,6 +173,9 @@ public: map_type& as_map() { check_throw(TYPE_MAP); return *m_map; } const map_type& as_map() const { check_throw(TYPE_MAP); return *m_map; } + SimpleStringBase& as_sstring() { check_throw(TYPE_SSTRING); return m_sstring; } + SimpleString as_sstring() const { check_throw(TYPE_SSTRING); return m_sstring; } + bool has_key(const key_type& k) const { check_throw(TYPE_MAP); return m_map->find(k) != m_map->end(); } bool has_key_value(const key_type& k) const { check_throw(TYPE_MAP); return check(m_map->find(k), TYPE_VALUE); } bool has_key_string(const key_type& k) const { check_throw(TYPE_MAP); return check(m_map->find(k), TYPE_STRING); } @@ -144,8 +187,6 @@ public: Object& get_key(const key_type& k); const Object& get_key(const key_type& k) const; - Object& get_key(const char* k); - const Object& get_key(const char* k) const; template value_type& get_key_value(const T& k) { return get_key(k).as_value(); } template const value_type& get_key_value(const T& k) const { return get_key(k).as_value(); } @@ -200,9 +241,31 @@ public: string_type* m_string; list_type* m_list; map_type* m_map; + SimpleStringBase m_sstring; }; }; +// We need to call our own insert function, so +// we have to define this operator to use that. +inline Object& +Object::map_type::operator[] (key_type key) { + base_type::iterator itr = lower_bound(key); + + if (itr == end() || key_comp()(key, itr->first)) + itr = insert(itr, value_type(key, mapped_type())); + + return itr->second; +} + +template +inline void +Object::map_type::insert(InputIterator itr, InputIterator itrEnd) { + while (itr != itrEnd) { + insert(end(), *itr); + ++itr; + } +} + inline Object::Object(const Object& b) : m_flags(b.type()) { switch (type()) { @@ -211,6 +274,7 @@ Object::Object(const Object& b) : m_flags(b.type()) { case TYPE_STRING: m_string = new string_type(*b.m_string); break; case TYPE_LIST: m_list = new list_type(*b.m_list); break; case TYPE_MAP: m_map = new map_type(*b.m_map); break; + case TYPE_SSTRING:m_sstring = b.m_sstring; break; } } @@ -222,6 +286,7 @@ Object::clear() { case TYPE_STRING: delete m_string; break; case TYPE_LIST: delete m_list; break; case TYPE_MAP: delete m_map; break; + case TYPE_SSTRING:break; } // Only clear type? diff --git a/src/torrent/object_stream.cc b/src/torrent/object_stream.cc index 18eb849..9d9a962 100644 --- a/src/torrent/object_stream.cc +++ b/src/torrent/object_stream.cc @@ -38,12 +38,14 @@ #include #include +#include #include #include "utils/sha1.h" #include "object.h" #include "object_stream.h" +#include "static_map.h" namespace torrent { @@ -63,6 +65,18 @@ object_read_string(std::istream* input, std::string& str) { return !input->fail(); } +Object +object_get_sstring(const char** buffer) { + /*const*/ char* next; + size_t length = strtoumax(*buffer, &next, 10); + + if (next == *buffer || *next != ':') + return Object(); + + *buffer = next + 1 + length; + return Object(next + 1, length); +} + // Could consider making this non-recursive, but they seldomly are // deep enough to make that worth-while. void @@ -159,6 +173,133 @@ object_read_bencode(std::istream* input, Object* object, uint32_t depth) { object->clear(); } +const char* +staticMap_read_bencode_c(const char* buffer, const char* bufferEnd, uint32_t depth, Object* values, const StaticMapKeys& keys, bool discard) { + if (buffer >= bufferEnd) + return bufferEnd; + + // Undecoded bencode object. + if (!discard && keys.type() == StaticMapKeys::TYPE_BENCODE) { + const char* begin = buffer; + buffer = staticMap_read_bencode_c(buffer, bufferEnd, ++depth, values, keys, true); + values[keys.index_begin()] = SimpleString(begin, buffer - begin); + return buffer; + } + + if (!discard && keys.type() == StaticMapKeys::TYPE_BENCODE_LIST && *buffer != 'l') + discard = true; + + switch (*buffer) { + case 'i': { + char* next; + intmax_t value = strtoimax(++buffer, &next, 10); + + if (next == buffer || next > bufferEnd || *next != 'e') + break; + + if (!discard && keys.type() == StaticMapKeys::TYPE_VALUE) + values[keys.index_begin()] = (int64_t)value; + + return next + 1; + } + + case 'l': { + ++buffer; + if (++depth >= 1024) + break; + + // Want undecoded bencode list: find end of list. + if (!discard && keys.type() == StaticMapKeys::TYPE_BENCODE_LIST) { + const char* end = buffer; + while (end < bufferEnd && *end != 'e') + end = staticMap_read_bencode_c(end, bufferEnd, depth, values, keys, true); + + values[keys.index_begin()] = SimpleString(buffer, end - buffer); + return ++end; + } + + StaticMapKeys::const_iterator itr = keys.begin(); + while (buffer != bufferEnd) { + if (*buffer == 'e') + return ++buffer; + + discard |= itr == keys.end(); + buffer = staticMap_read_bencode_c(buffer, bufferEnd, depth, values, discard ? keys : *itr, discard); + + if (itr != keys.end()) + ++itr; + } + + break; + } + + case 'd': { + ++buffer; + if (++depth >= 1024) + break; + + StaticMapKeys::const_iterator itr = keys.begin(); + SimpleString last; + bool discardThis = discard; + + while (buffer != bufferEnd) { + if (*buffer == 'e') + return ++buffer; + + Object keyObj = object_get_sstring(&buffer); + if (!keyObj.is_sstring()) + break; + + SimpleString key = keyObj.as_sstring(); + if (key.end() >= bufferEnd) + break; + + if (key < last) { + itr = keys.begin(); + discardThis = discard; + } + + discardThis |= itr == keys.end(); + int cmp = discardThis ? -1 : key.cmp(itr->key()); + while (cmp > 0) { + if (++itr == keys.end()) { + cmp = -1; + discardThis = true; + break; + } + + cmp = key.cmp(itr->key()); + } + + buffer = staticMap_read_bencode_c(buffer, bufferEnd, depth, values, cmp ? keys : *itr, cmp); + + last = key; + } + + break; + } + + default: + if (*buffer < '0' || *buffer > '9') + break; + + Object strObj = object_get_sstring(&buffer); + if (!strObj.is_sstring()) + break; + + SimpleString str = strObj.as_sstring(); + if (str.end() >= bufferEnd) + break; + + if (!discard && keys.type() == StaticMapKeys::TYPE_VALUE) + values[keys.index_begin()] = str; + + return str.end(); + } + + throw bencode_error("Invalid bencode data."); +} + void object_write_bencode(std::ostream* output, const Object* object) { char buffer[1024]; @@ -267,6 +408,7 @@ void object_write_bencode_c_object(object_write_data_t* output, const Object* object) { switch (object->type()) { case Object::TYPE_NONE: + case Object::TYPE_SSTRING: break; case Object::TYPE_VALUE: @@ -306,6 +448,86 @@ object_write_bencode_c_object(object_write_data_t* output, const Object* object) } } +void +staticMap_write_bencode_c_values(object_write_data_t* output, const Object* values, const StaticMapKeys& keys) { + if (keys.type() == StaticMapKeys::TYPE_LIST) { + size_t indexEnd = keys.index_begin(); + while (indexEnd < keys.index_end() && values[indexEnd].type() != Object::TYPE_NONE) + indexEnd++; + + // Empty list? Drop it. Sparse lists are not possible so only check first element. + if (indexEnd == keys.index_begin()) + return; + + object_write_bencode_c_char(output, 'l'); + StaticMapKeys::const_iterator itr = keys.begin(); + size_t index = keys.index_begin(); + while (index < indexEnd) { + staticMap_write_bencode_c_values(output, values, *itr); + index = itr->index_end(); + if (++itr == keys.end() && index != indexEnd) + throw internal_error("staticMap_write_bencode_c_values reached end of list before end of index list."); + } + object_write_bencode_c_char(output, 'e'); + + } else if (keys.type() == StaticMapKeys::TYPE_DICT) { + // Find next non-empty entry. + size_t next = keys.index_begin(); + while (values[next].type() == Object::TYPE_NONE) + if (++next == keys.index_end()) + return; + + object_write_bencode_c_char(output, 'd'); + StaticMapKeys::const_iterator itr = keys.begin(); + while (next < keys.index_end()) { + while (itr->index_end() <= next) + if (++itr == keys.end()) + throw internal_error("staticMap_write_bencode_c_values reached end of keys before end of index list."); + + object_write_bencode_c_value(output, itr->key().size()); + object_write_bencode_c_char(output, ':'); + object_write_bencode_c_string(output, itr->key().c_str(), itr->key().size()); + + staticMap_write_bencode_c_values(output, values, *itr); + + next = itr->index_end(); + while (next < keys.index_end() && values[next].type() == Object::TYPE_NONE) + ++next; + } + object_write_bencode_c_char(output, 'e'); + + // Undecoded bencode value. + } else if (keys.type() == StaticMapKeys::TYPE_BENCODE) { + SimpleString value = values[keys.index_begin()].as_sstring(); + object_write_bencode_c_string(output, value.c_str(), value.size()); + + } else if (keys.type() == StaticMapKeys::TYPE_BENCODE_LIST) { + SimpleString value = values[keys.index_begin()].as_sstring(); + object_write_bencode_c_char(output, 'l'); + object_write_bencode_c_string(output, value.c_str(), value.size()); + object_write_bencode_c_char(output, 'e'); + + } else if (keys.type() != StaticMapKeys::TYPE_VALUE) { + throw internal_error("staticMap_write_bencode_c_values received key keys with invalid values type."); + + } else if (values[keys.index_begin()].type() == Object::TYPE_NONE) { + + } else if (values[keys.index_begin()].type() == Object::TYPE_VALUE) { + object_write_bencode_c_char(output, 'i'); + object_write_bencode_c_value(output, values[keys.index_begin()].as_value()); + object_write_bencode_c_char(output, 'e'); + + } else if (values[keys.index_begin()].type() == Object::TYPE_SSTRING) { + SimpleString value = values[keys.index_begin()].as_sstring(); + object_write_bencode_c_value(output, value.size()); + object_write_bencode_c_char(output, ':'); + object_write_bencode_c_string(output, value.c_str(), value.size()); + + } else { + throw internal_error("staticMap_write_bencode_c_values received key keys with invalid values type."); + } +} + object_buffer_t object_write_bencode_c(object_write_t writeFunc, void* data, object_buffer_t buffer, const Object* object) { object_write_data_t output; @@ -327,6 +549,32 @@ object_write_bencode_c(object_write_t writeFunc, void* data, object_buffer_t buf } object_buffer_t +staticMap_write_bencode_c_wrap(object_write_t writeFunc, void* data, object_buffer_t buffer, const Object* values, const StaticMapKeys& map) { + object_write_data_t output; + output.writeFunc = writeFunc; + output.data = data; + output.buffer = buffer; + output.pos = buffer.first; + + staticMap_write_bencode_c_values(&output, values, map); +#ifdef USE_EXTRA_DEBUG + std::istringstream sstream; + sstream.imbue(std::locale::classic()); + sstream.str(std::string(output.buffer.first, output.pos)); + Object request; + sstream >> request; + if (sstream.fail()) + throw internal_error("staticMap_write_bencode_c_wrap failed to create valid bencode format."); +#endif + + // Don't flush the buffer. + if (output.pos == output.buffer.first) + return output.buffer; + + return output.writeFunc(output.data, object_buffer_t(output.buffer.first, output.pos)); +} + +object_buffer_t object_write_to_buffer(void* data, object_buffer_t buffer) { if (buffer.first == buffer.second) throw internal_error("object_write_to_buffer(...) buffer overflow."); @@ -352,4 +600,11 @@ object_write_to_stream(void* data, object_buffer_t buffer) { return buffer; } +object_buffer_t +object_write_to_size(void* data, object_buffer_t buffer) { + *reinterpret_cast(data) += std::distance(buffer.first, buffer.second); + + return buffer; +} + } diff --git a/src/torrent/object_stream.h b/src/torrent/object_stream.h index 41cf82a..3de5d82 100644 --- a/src/torrent/object_stream.h +++ b/src/torrent/object_stream.h @@ -43,6 +43,10 @@ namespace torrent { +template +class StaticMap; +class StaticMapKeys; + std::string object_sha1(const Object* object) LIBTORRENT_EXPORT; // Assumes the stream's locale has been set to POSIX or C. Max depth @@ -53,6 +57,18 @@ void object_read_bencode(std::istream* input, Object* object, uint32_t depth = 0 // Assumes the stream's locale has been set to POSIX or C. void object_write_bencode(std::ostream* output, const Object* object) LIBTORRENT_EXPORT; +// Convert buffer to static key map. Inlined because we don't want +// a separate wrapper function for each template argument. +template +inline const char* +staticMap_read_bencode(const char* buffer, const char* bufferEnd, StaticMap& map) { + return staticMap_read_bencode_c(buffer, bufferEnd, 0, map.values(), map.map(), false); +}; + +// Internal use only. +const char* +staticMap_read_bencode_c(const char* buffer, const char* bufferEnd, uint32_t depth, Object* values, const StaticMapKeys& keys, bool discard); + std::istream& operator >> (std::istream& input, Object& object) LIBTORRENT_EXPORT; std::ostream& operator << (std::ostream& output, const Object& object) LIBTORRENT_EXPORT; @@ -62,10 +78,22 @@ typedef object_buffer_t (*object_write_t)(void* data, object_buffer_t buffer); object_buffer_t object_write_bencode_c(object_write_t writeFunc, void* data, object_buffer_t buffer, const Object* object) LIBTORRENT_EXPORT; +template +inline object_buffer_t +staticMap_write_bencode_c(object_write_t writeFunc, void* data, object_buffer_t buffer, const StaticMap& object) { + return staticMap_write_bencode_c_wrap(writeFunc, data, buffer, object.values(), object.map()); +} + +// Internal use only. +object_buffer_t staticMap_write_bencode_c_wrap(object_write_t writeFunc, void* data, object_buffer_t buffer, const Object* values, const StaticMapKeys& keys) LIBTORRENT_EXPORT; + // To char buffer. 'data' is NULL. object_buffer_t object_write_to_buffer(void* data, object_buffer_t buffer) LIBTORRENT_EXPORT; object_buffer_t object_write_to_sha1(void* data, object_buffer_t buffer) LIBTORRENT_EXPORT; object_buffer_t object_write_to_stream(void* data, object_buffer_t buffer) LIBTORRENT_EXPORT; + +// Measures bencode size, 'data' is uint64_t*. +object_buffer_t object_write_to_size(void* data, object_buffer_t buffer) LIBTORRENT_EXPORT; } #endif diff --git a/src/torrent/simple_string.h b/src/torrent/simple_string.h new file mode 100644 index 0000000..8eaf3b7 --- /dev/null +++ b/src/torrent/simple_string.h @@ -0,0 +1,129 @@ +// libTorrent - BitTorrent library +// Copyright (C) 2005-2008, Jari Sundell +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation; either version 2 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program; if not, write to the Free Software +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +// +// In addition, as a special exception, the copyright holders give +// permission to link the code of portions of this program with the +// OpenSSL library under certain conditions as described in each +// individual source file, and distribute linked combinations +// including the two. +// +// You must obey the GNU General Public License in all respects for +// all of the code used other than OpenSSL. If you modify file(s) +// with this exception, you may extend this exception to your version +// of the file(s), but you are not obligated to do so. If you do not +// wish to do so, delete this exception statement from your version. +// If you delete this exception statement from all source files in the +// program, then also delete it here. +// +// Contact: Jari Sundell +// +// Skomakerveien 33 +// 3185 Skoppum, NORWAY + +// A simple string with no constructors (i.e. POD) plus a derived +// class with constructors and conversion operators. In most cases, +// SimpleString is the class to use, except in unions where it needs +// to be SimpleStringBase. +// +// For efficient conversion from C string literals, depends on the +// compiler optimizing strlen("literal") to an integer literal. +// Then a comparison with either a C string literal or a SimpleString +// literal is a memcmp call plus (if equal) a comparison of the lengths. + +#ifndef LIBTORRENT_SIMPLE_STRING_H +#define LIBTORRENT_SIMPLE_STRING_H + +#include +#include +#include +#include + +namespace torrent { + +// Simple string base class (POD). +struct LIBTORRENT_EXPORT SimpleStringBase { + int cmp(const SimpleStringBase& other) const; + + char operator [] (size_t index) const { return m_data[index]; } + + const char* begin() const { return m_data; } + const char* end() const { return m_data + m_length; } + + // NOTE: Unlike std::string, SimpleString's c_str() is NOT guaranteed to be zero-terminated! + const char* c_str() const { return m_data; } + const char* data() const { return m_data; } + + bool empty() const { return !m_length; } + size_t length() const { return m_length; } + size_t size() const { return m_length; } + + std::string str() const { return std::string(m_data, m_length); } + std::string substr(size_t pos = 0, size_t n = npos) const { return std::string(m_data + pos, std::min(m_length - pos, n)); } + + // Allocates a copy of the string and returns it. + SimpleStringBase copy() const; + + static const size_t npos = static_cast(-1); + +protected: + const char* m_data; + size_t m_length; +}; + +// Conversion helper class, we don't want constructors +// in the base class to be able to put it in a union. +struct LIBTORRENT_EXPORT SimpleString : public SimpleStringBase { + typedef SimpleStringBase base_type; + + SimpleString() { m_data = ""; m_length = 0; } + SimpleString(const base_type& s) { m_data = s.c_str(); m_length = s.length(); } + SimpleString(const std::string& s) { m_data = s.c_str(); m_length = s.length(); } + SimpleString(const char* s) { m_data = s; m_length = strlen(s); } + SimpleString(const char* s, size_t l) { m_data = s; m_length = l; } +}; + +inline int +SimpleStringBase::cmp(const SimpleStringBase& other) const { + int cmp = memcmp(m_data, other.m_data, std::min(m_length, other.m_length)); + return cmp ? cmp : m_length - other.m_length; +} + +inline SimpleStringBase +SimpleStringBase::copy() const { + char* data = new char[m_length + 1]; + memcpy(data, m_data, m_length); + data[m_length] = 0; + return SimpleString(data, m_length); +} + +inline bool operator == (const SimpleStringBase& one, const SimpleStringBase& other) { return one.cmp(other) == 0; } +inline bool operator != (const SimpleStringBase& one, const SimpleStringBase& other) { return one.cmp(other) != 0; } +inline bool operator <= (const SimpleStringBase& one, const SimpleStringBase& other) { return one.cmp(other) <= 0; } +inline bool operator < (const SimpleStringBase& one, const SimpleStringBase& other) { return one.cmp(other) < 0; } +inline bool operator >= (const SimpleStringBase& one, const SimpleStringBase& other) { return one.cmp(other) >= 0; } +inline bool operator > (const SimpleStringBase& one, const SimpleStringBase& other) { return one.cmp(other) > 0; } + +inline bool operator == (const SimpleStringBase& one, const char* other) { return one.cmp(SimpleString(other)) == 0; } +inline bool operator != (const SimpleStringBase& one, const char* other) { return one.cmp(SimpleString(other)) != 0; } +inline bool operator <= (const SimpleStringBase& one, const char* other) { return one.cmp(SimpleString(other)) <= 0; } +inline bool operator < (const SimpleStringBase& one, const char* other) { return one.cmp(SimpleString(other)) < 0; } +inline bool operator >= (const SimpleStringBase& one, const char* other) { return one.cmp(SimpleString(other)) >= 0; } +inline bool operator > (const SimpleStringBase& one, const char* other) { return one.cmp(SimpleString(other)) > 0; } + +} + +#endif diff --git a/src/torrent/static_map.cc b/src/torrent/static_map.cc new file mode 100644 index 0000000..b71f257 --- /dev/null +++ b/src/torrent/static_map.cc @@ -0,0 +1,123 @@ +// libTorrent - BitTorrent library +// Copyright (C) 2005-2008, Jari Sundell +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation; either version 2 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program; if not, write to the Free Software +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +// +// In addition, as a special exception, the copyright holders give +// permission to link the code of portions of this program with the +// OpenSSL library under certain conditions as described in each +// individual source file, and distribute linked combinations +// including the two. +// +// You must obey the GNU General Public License in all respects for +// all of the code used other than OpenSSL. If you modify file(s) +// with this exception, you may extend this exception to your version +// of the file(s), but you are not obligated to do so. If you do not +// wish to do so, delete this exception statement from your version. +// If you delete this exception statement from all source files in the +// program, then also delete it here. +// +// Contact: Jari Sundell +// +// Skomakerveien 33 +// 3185 Skoppum, NORWAY + +#include "config.h" + +#include "static_map.h" + +namespace torrent { + +inline int +StaticMapKeys::check_key_order(SimpleString key) { + int cmp = empty() ? -1 : back().key().cmp(key); + if (cmp > 0) { + if (type() == TYPE_LIST) + cmp = -1; // List order is given by indices, not alphabetically. + else + throw internal_error("StaticMapKeys::StaticMapKeys() called with unsorted keys."); + } + + return cmp; +} + +StaticMapKeys::StaticMapKeys(const mapping_type* key_list, size_t length) + : m_key(SimpleString("root", 4)), + m_indexBegin(0), + m_indexEnd(0), + m_type(key_list[0].key[0] == '[' ? TYPE_LIST : TYPE_DICT) { + + for (size_t index = 0; index < length; index++, key_list++) { + if (key_list->index != index) + throw internal_error("StaticMapKeys::StaticMapKeys() used with list not in index order."); + + StaticMapKeys* curMap = this; + const char* key = key_list->key; + while (key != NULL && *key) { + curMap->set_end(index + 1); + + const char* sep = key + 1 + strcspn(key + 1, ":["); + SimpleString keyStr(key, sep - key); + + // New key, in correct order? Or same key as before? + int cmp = curMap->check_key_order(keyStr); + + if (sep[0] == 0) { + curMap->insert(curMap->end(), StaticMapKeys(keyStr, TYPE_VALUE, index, index + 1)); + break; + + } else if (sep[0] == '[' && sep[1] == ']' && sep[2] == 0) { + curMap->insert(curMap->end(), StaticMapKeys(keyStr, TYPE_BENCODE_LIST, index, index + 1)); + break; + + } else if (sep[0] == ':' && sep[1] == ':' && sep[2] == 0) { + curMap->insert(curMap->end(), StaticMapKeys(keyStr, TYPE_BENCODE, index, index + 1)); + break; + } + + if (sep[0] == ':' && sep[1] == ':') { + if (cmp < 0) + curMap->insert(curMap->end(), StaticMapKeys(keyStr, TYPE_DICT, index, index + 1)); + else if (curMap->back().type() != TYPE_DICT) + throw internal_error("StaticMapKeys::StaticMapKeys() called with a mixed dictionary/list entry."); + + curMap = &curMap->back(); + key = sep + 2; + + } else if (sep[0] == '[' && sep[1] >= '0' && sep[1] <= '9') { + key = sep++; + while (*sep >= '0' && *sep <= '9') + ++sep; + if (*sep != ']') + throw internal_error("StaticMapKeys::StaticMapKeys() called with invalid list index."); + + if (cmp < 0) + curMap->insert(curMap->end(), StaticMapKeys(keyStr, TYPE_LIST, index, index + 1)); + else if (curMap->back().type() != TYPE_LIST) + throw internal_error("StaticMapKeys::StaticMapKeys() called with a mixed dictionary/list entry."); + + curMap = &curMap->back(); + + } else { + throw internal_error("StaticMapKeys::StaticMapKeys() called with unsupported key type."); + } + } + } + + if (index_end() != length) + throw internal_error("StaticMapKeys::StaticMapKeys() is missing values."); +} + +} diff --git a/src/torrent/static_map.h b/src/torrent/static_map.h new file mode 100644 index 0000000..d862f16 --- /dev/null +++ b/src/torrent/static_map.h @@ -0,0 +1,158 @@ +// libTorrent - BitTorrent library +// Copyright (C) 2005-2008, Jari Sundell +// +// This program is free software; you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation; either version 2 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program; if not, write to the Free Software +// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +// +// In addition, as a special exception, the copyright holders give +// permission to link the code of portions of this program with the +// OpenSSL library under certain conditions as described in each +// individual source file, and distribute linked combinations +// including the two. +// +// You must obey the GNU General Public License in all respects for +// all of the code used other than OpenSSL. If you modify file(s) +// with this exception, you may extend this exception to your version +// of the file(s), but you are not obligated to do so. If you do not +// wish to do so, delete this exception statement from your version. +// If you delete this exception statement from all source files in the +// program, then also delete it here. +// +// Contact: Jari Sundell +// +// Skomakerveien 33 +// 3185 Skoppum, NORWAY + +#ifndef LIBTORRENT_STATIC_MAP_H +#define LIBTORRENT_STATIC_MAP_H + +#include +#include +#include +#include + +// StaticMap: holds a pre-defined subset of possible bencode keys and stores +// their values in a flat array for fast decoding, key access and encoding. +// Makes no copies, so the underlying data buffer must outlive the map object. + +// With this, the complexity for bencoding and bdecoding a StaticMap object +// is O(n). The access to any of the pre-defined keys is O(1). Access to +// other keys is not supported, they are dropped while bdecoding. Decoded +// Object types are either VALUE, SSTRING or NONE (if key was not present). + +// To use, define an enum of all required keys, and use this type along with +// the number of possible keys in the StaticMap template arguments. Define +// the enum -> key string as array of StaticMapKeys::mapping_type. Define +// the static keyMap variable, most simply by defining base_type in your +// derived map class, like this: +// template<> const Derived::key_map_init Derived::base_type::keyMap(key_list); + +// The argument of the constructor of this static keyMap object is a list +// of mapping_type entries. For efficiency, they must be ordered in +// increasing number of the index, and increasing alphabetical order +// (or more specifically, the bencode order) at the same time. In other words, +// the original enum must also be in alphabetical order of the keys the enum +// values refer to. + +// Format of the key specifications ("..." may contain any number of further keys): +// "foo::..." makes foo a bencode dictionary +// "foo[0]..." makes foo a bencode list +// "foo::" makes foo an undecoded bencode value (may contain arbitrary bencode data) +// "foo[]" makes foo an undecoded list of bencode values (like the above but adding the 'l' and 'e' indicators) +// "foo" makes foo an integer or string value (automatic) +// +// Examples: +// "baz" refers to a single value for key "baz" +// "foo::a[0]::bar" refers to a single value for key "bar" in the dictionary at index 0 of the list for key "a" in dictionary "foo" +// "foo::a[1]" refers to a single value at index 1 of the list for key "a" in the dictionary "foo" +// "zoo::" refers to a bdecoded value for key "zoo" +// +// If the four values are 4, 5, "6" and 7, this would be bencoded as d3:bazi4e3:food1:ald3:bari5ee1:6ee3:zooi7ee +// +// Note that sparse lists are not possible, you must explicitly specify all needed entries starting from index 0, +// and when bencoding, the first unset value terminates the list. + +namespace torrent { + +// Hierarchical structure mapping bencode keys to flat array indices. +class LIBTORRENT_EXPORT StaticMapKeys : public std::vector { +public: + typedef std::vector base_type; + + struct mapping_type { + size_t index; + const char* key; + }; + + enum value_type { + TYPE_VALUE, + TYPE_LIST, + TYPE_DICT, + TYPE_BENCODE, + TYPE_BENCODE_LIST, + }; + + StaticMapKeys(const mapping_type* key_list, size_t length); + + void set_end(size_t end) { m_indexEnd = end; } + + size_t index_begin() const { return m_indexBegin; } + size_t index_end() const { return m_indexEnd; } + + value_type type() const { return m_type; } + + SimpleString key() const { return m_key; } + +private: + StaticMapKeys(SimpleString key, value_type type, size_t begin, size_t end) + : m_key(key), m_indexBegin(begin), m_indexEnd(end), m_type(type) {} + + int check_key_order(SimpleString key); + + SimpleString m_key; + size_t m_indexBegin; + size_t m_indexEnd; + value_type m_type; +}; + +template +class LIBTORRENT_EXPORT StaticMap { +public: + typedef Object& value_type; + typedef tmpl_key_type key_type; + typedef StaticMapKeys key_map_type; + typedef Object list_type[tmpl_length]; + + Object& operator [] (key_type key) { return m_values[key]; } + const Object& operator [] (key_type key) const { return m_values[key]; } + + const key_map_type& map() const { return keyMap; } + + list_type& values() { return m_values; } + const list_type& values() const { return m_values; } + + static const size_t length = tmpl_length; + +private: + struct key_map_init : public key_map_type { + key_map_init(key_map_type::mapping_type* key_list) : key_map_type(key_list, tmpl_length) {}; + }; + static const key_map_init keyMap; + + list_type m_values; +}; + +} + +#endif diff --git a/src/torrent/torrent.cc b/src/torrent/torrent.cc index e8ffbac..47027cc 100644 --- a/src/torrent/torrent.cc +++ b/src/torrent/torrent.cc @@ -350,11 +350,22 @@ download_add(Object* object) { ctor.initialize(*object); - std::string infoHash = object_sha1(&object->get_key("info")); + std::string infoHash; + if (download->info()->is_meta_download()) + infoHash = object->get_key("info").get_key("pieces").as_string(); + else + infoHash = object_sha1(&object->get_key("info")); if (manager->download_manager()->find(infoHash) != manager->download_manager()->end()) throw input_error("Info hash already used by another torrent."); + if (!download->info()->is_meta_download()) { + char buffer[1024]; + uint64_t metadata_size = 0; + object_write_bencode_c(&object_write_to_size, &metadata_size, object_buffer_t(buffer, buffer + sizeof(buffer)), &object->get_key("info")); + download->main()->set_metadata_size(metadata_size); + } + download->set_hash_queue(manager->hash_queue()); download->initialize(infoHash, PEER_NAME + rak::generate_random(20 - std::string(PEER_NAME).size())); diff --git a/src/tracker/tracker_dht.cc b/src/tracker/tracker_dht.cc index c63ce58..309fcf2 100644 --- a/src/tracker/tracker_dht.cc +++ b/src/tracker/tracker_dht.cc @@ -115,13 +115,11 @@ TrackerDht::type() const { } void -TrackerDht::receive_peers(const Object& peer_list) { +TrackerDht::receive_peers(SimpleString peers) { if (!is_busy()) throw internal_error("TrackerDht::receive_peers called while not busy."); - Object::list_type peers = peer_list.as_list(); - for (Object::list_type::const_iterator itr = peers.begin(); itr != peers.end(); ++itr) - m_peers.parse_address_compact(itr->as_string()); + m_peers.parse_address_bencode(peers); } void diff --git a/src/tracker/tracker_dht.h b/src/tracker/tracker_dht.h index d197e61..d096b46 100644 --- a/src/tracker/tracker_dht.h +++ b/src/tracker/tracker_dht.h @@ -71,7 +71,7 @@ public: bool has_peers() const { return !m_peers.empty(); } - void receive_peers(const Object& peer_list); + void receive_peers(SimpleString peers); void receive_success(); void receive_failed(const char* msg); void receive_progress(int replied, int contacted);