From f749f2d21c1b47e6dcd626633acff764a4484b99 Mon Sep 17 00:00:00 2001 From: Dhammika Pathirana Date: Mon, 13 Dec 2010 15:40:26 +0100 Subject: add basic uri validations Signed-off-by: Dhammika Pathirana --- src/socket_base.cpp | 50 ++++++++++++++++++++++++++++---------------------- 1 file changed, 28 insertions(+), 22 deletions(-) (limited to 'src/socket_base.cpp') diff --git a/src/socket_base.cpp b/src/socket_base.cpp index 2fe7bfd..248c1e3 100644 --- a/src/socket_base.cpp +++ b/src/socket_base.cpp @@ -141,6 +141,26 @@ void zmq::socket_base_t::stop () send_stop (); } +int zmq::socket_base_t::parse_uri (const char *uri_, + std::string &protocol_, std::string &address_) +{ + zmq_assert (uri_ != NULL); + + std::string uri (uri_); + std::string::size_type pos = uri.find ("://"); + if (pos == std::string::npos) { + errno = EINVAL; + return -1; + } + protocol_ = uri.substr (0, pos); + address_ = uri.substr (pos + 3); + if (protocol_.empty () || address_.empty ()) { + errno = EINVAL; + return -1; + } + return 0; +} + int zmq::socket_base_t::check_protocol (const std::string &protocol_) { // First check out whether the protcol is something we are aware of. @@ -272,18 +292,11 @@ int zmq::socket_base_t::bind (const char *addr_) // Parse addr_ string. std::string protocol; std::string address; - { - std::string addr (addr_); - std::string::size_type pos = addr.find ("://"); - if (pos == std::string::npos) { - errno = EINVAL; - return -1; - } - protocol = addr.substr (0, pos); - address = addr.substr (pos + 3); - } + int rc = parse_uri (addr_, protocol, address); + if (rc != 0) + return -1; - int rc = check_protocol (protocol); + rc = check_protocol (protocol); if (rc != 0) return -1; @@ -334,18 +347,11 @@ int zmq::socket_base_t::connect (const char *addr_) // Parse addr_ string. std::string protocol; std::string address; - { - std::string addr (addr_); - std::string::size_type pos = addr.find ("://"); - if (pos == std::string::npos) { - errno = EINVAL; - return -1; - } - protocol = addr.substr (0, pos); - address = addr.substr (pos + 3); - } + int rc = parse_uri (addr_, protocol, address); + if (rc != 0) + return -1; - int rc = check_protocol (protocol); + rc = check_protocol (protocol); if (rc != 0) return -1; -- cgit v1.2.3