diff options
Diffstat (limited to 'src/net')
-rw-r--r-- | src/net/fwd.hpp | 33 | ||||
-rw-r--r-- | src/net/ip.cpp | 120 | ||||
-rw-r--r-- | src/net/ip.hpp | 166 | ||||
-rw-r--r-- | src/net/ip.py | 14 | ||||
-rw-r--r-- | src/net/ip_test.cpp | 359 | ||||
-rw-r--r-- | src/net/packets.cpp | 106 | ||||
-rw-r--r-- | src/net/packets.hpp | 585 | ||||
-rw-r--r-- | src/net/socket.cpp | 487 | ||||
-rw-r--r-- | src/net/socket.hpp | 177 | ||||
-rw-r--r-- | src/net/timer.cpp | 221 | ||||
-rw-r--r-- | src/net/timer.hpp | 51 | ||||
-rw-r--r-- | src/net/timer.t.hpp | 164 |
12 files changed, 2483 insertions, 0 deletions
diff --git a/src/net/fwd.hpp b/src/net/fwd.hpp new file mode 100644 index 0000000..2097772 --- /dev/null +++ b/src/net/fwd.hpp @@ -0,0 +1,33 @@ +#pragma once +// net/fwd.hpp - list of type names for net lib +// +// Copyright © 2014 Ben Longbons <b.r.longbons@gmail.com> +// +// This file is part of The Mana World (Athena server) +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +#include "../sanity.hpp" + + +namespace tmwa +{ +class Session; + +class IP4Address; + +class TimerData; + +enum class RecvResult; +} // namespace tmwa diff --git a/src/net/ip.cpp b/src/net/ip.cpp new file mode 100644 index 0000000..bfc2028 --- /dev/null +++ b/src/net/ip.cpp @@ -0,0 +1,120 @@ +#include "ip.hpp" +// ip.cpp - Implementation of IP address functions. +// +// Copyright © 2013 Ben Longbons <b.r.longbons@gmail.com> +// +// This file is part of The Mana World (Athena server) +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +#include "../strings/xstring.hpp" +#include "../strings/vstring.hpp" + +#include "../io/cxxstdio.hpp" + +#include "../mmo/extract.hpp" + +#include "../poison.hpp" + + +namespace tmwa +{ +bool extract(XString str, IP4Address *rv) +{ + if (str.endswith('.')) + return false; + uint8_t buf[4]; + if (extract(str, record<'.'>(&buf[0], &buf[1], &buf[2], &buf[3]))) + { + *rv = IP4Address(buf); + return true; + } + return false; +} + +bool extract(XString str, IP4Mask *rv) +{ + IP4Address a, m; + unsigned b; + XString l, r; + if (str.endswith('/')) + return false; + if (extract(str, record<'/'>(&l, &r))) + { + // a.b.c.d/e.f.g.h or a.b.c.d/n + if (!extract(l, &a)) + return false; + if (extract(r, &m)) + { + *rv = IP4Mask(a, m); + return true; + } + if (!extract(r, &b) || b > 32) + return false; + } + else + { + // a.b.c.d or a.b.c.d. or a.b.c. or a.b. or a. + if (extract(str, &a)) + { + *rv = IP4Mask(a, IP4_BROADCAST); + return true; + } + if (!str.endswith('.')) + return false; + uint8_t d[4] {}; + if (extract(str, record<'.'>(&d[0], &d[1], &d[2], &d[3]))) + b = 32; + else if (extract(str, record<'.'>(&d[0], &d[1], &d[2]))) + b = 24; + else if (extract(str, record<'.'>(&d[0], &d[1]))) + b = 16; + else if (extract(str, record<'.'>(&d[0]))) + b = 8; + else + return false; + a = IP4Address(d); + } + // a is set; need to construct m from b + if (b == 0) + m = IP4Address(); + else if (b == 32) + m = IP4_BROADCAST; + else + { + uint32_t s = -1; + s <<= (32 - b); + m = IP4Address({ + static_cast<uint8_t>(s >> 24), + static_cast<uint8_t>(s >> 16), + static_cast<uint8_t>(s >> 8), + static_cast<uint8_t>(s >> 0), + }); + } + *rv = IP4Mask(a, m); + return true; +} + +VString<15> convert_for_printf(IP4Address a_) +{ + const uint8_t *a = a_.bytes(); + return STRNPRINTF(16, "%hhu.%hhu.%hhu.%hhu"_fmt, a[0], a[1], a[2], a[3]); +} + +VString<31> convert_for_printf(IP4Mask a) +{ + return STRNPRINTF(32, "%s/%s"_fmt, + a.addr(), a.mask()); +} +} // namespace tmwa diff --git a/src/net/ip.hpp b/src/net/ip.hpp new file mode 100644 index 0000000..e9e71f4 --- /dev/null +++ b/src/net/ip.hpp @@ -0,0 +1,166 @@ +#pragma once +// ip.hpp - classes to deal with IP addresses. +// +// Copyright © 2013 Ben Longbons <b.r.longbons@gmail.com> +// +// This file is part of The Mana World (Athena server) +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +#include "fwd.hpp" + +#include <netinet/in.h> + +#include <cstddef> +#include <cstdint> + +#include "../strings/fwd.hpp" + + +namespace tmwa +{ +// TODO - in the long run ports belong here also +// and of course, IPv6 stuff. +// But what about unix socket addresses? + +/// Helper function +template<class T, size_t n> +constexpr +bool _ce_a_lt(T (&a)[n], T (&b)[n], size_t i=0) +{ + return (i != n + && (a[i] < b[i] + || (a[i] == b[i] + && _ce_a_lt(a, b, i + 1)))); +} + +/// A 32-bit Ipv4 address. Does not include a port. +/// Guaranteed to be laid out like the network wants. +class IP4Address +{ + uint8_t _addr[4]; +public: + constexpr + IP4Address() + : _addr{} + {} + constexpr explicit + IP4Address(const uint8_t (&a)[4]) + : _addr{a[0], a[1], a[2], a[3]} + {} + explicit + IP4Address(struct in_addr addr) + { + static_assert(sizeof(addr) == sizeof(_addr), "4 bytes"); + *this = IP4Address(reinterpret_cast<const uint8_t (&)[4]>(addr)); + } + explicit + operator struct in_addr() const + { + return reinterpret_cast<const struct in_addr&>(_addr); + } + + constexpr friend + IP4Address operator & (IP4Address l, IP4Address r) + { + return IP4Address({ + static_cast<uint8_t>(l._addr[0] & r._addr[0]), + static_cast<uint8_t>(l._addr[1] & r._addr[1]), + static_cast<uint8_t>(l._addr[2] & r._addr[2]), + static_cast<uint8_t>(l._addr[3] & r._addr[3]), + }); + } + + IP4Address& operator &= (IP4Address m) + { return *this = *this & m; } + + const uint8_t *bytes() const + { return _addr; } + + constexpr friend + bool operator < (IP4Address l, IP4Address r) + { + return _ce_a_lt(l._addr, r._addr); + } + + constexpr friend + bool operator > (IP4Address l, IP4Address r) + { + return _ce_a_lt(r._addr, l._addr); + } + + constexpr friend + bool operator >= (IP4Address l, IP4Address r) + { + return !_ce_a_lt(l._addr, r._addr); + } + + constexpr friend + bool operator <= (IP4Address l, IP4Address r) + { + return !_ce_a_lt(r._addr, l._addr); + } + + constexpr friend + bool operator == (IP4Address l, IP4Address r) + { + return !(l < r || r < l); + } + + constexpr friend + bool operator != (IP4Address l, IP4Address r) + { + return (l < r || r < l); + } +}; + +class IP4Mask +{ + IP4Address _addr, _mask; +public: + constexpr + IP4Mask() : _addr(), _mask() + {} + constexpr + IP4Mask(IP4Address a, IP4Address m) : _addr(a & m), _mask(m) + {} + + constexpr + IP4Address addr() const + { return _addr; } + constexpr + IP4Address mask() const + { return _mask; } + + constexpr + bool covers(IP4Address a) const + { + return (a & _mask) == _addr; + } +}; + + +constexpr +IP4Address IP4_LOCALHOST({127, 0, 0, 1}); +constexpr +IP4Address IP4_BROADCAST({255, 255, 255, 255}); + + +VString<15> convert_for_printf(IP4Address a); +VString<31> convert_for_printf(IP4Mask m); + +bool extract(XString str, IP4Address *iv); + +bool extract(XString str, IP4Mask *iv); +} // namespace tmwa diff --git a/src/net/ip.py b/src/net/ip.py new file mode 100644 index 0000000..bcf90a2 --- /dev/null +++ b/src/net/ip.py @@ -0,0 +1,14 @@ +class IP4Address(object): + ''' print an IP4Address + ''' + __slots__ = ('_value') + name = 'tmwa::IP4Address' + enabled = True + + def __init__(self, value): + self._value = value + + def to_string(self): + addr = self._value['_addr'] + addr = tuple(int(addr[i]) for i in range(4)) + return '%d.%d.%d.%d' % addr diff --git a/src/net/ip_test.cpp b/src/net/ip_test.cpp new file mode 100644 index 0000000..419dc03 --- /dev/null +++ b/src/net/ip_test.cpp @@ -0,0 +1,359 @@ +#include "ip.hpp" +// ip_test.cpp - Testsuite for implementation of IP address functions. +// +// Copyright © 2013 Ben Longbons <b.r.longbons@gmail.com> +// +// This file is part of The Mana World (Athena server) +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +#include <gtest/gtest.h> + +#include "../strings/vstring.hpp" +#include "../strings/literal.hpp" + +#include "../io/cxxstdio.hpp" + +#include "../poison.hpp" + + +namespace tmwa +{ +#define CB(X) (std::integral_constant<bool, (X)>::value) +TEST(ip4addr, cmp) +{ + constexpr static + IP4Address a = IP4_LOCALHOST; + constexpr static + IP4Address b = IP4_BROADCAST; + + EXPECT_FALSE(CB(a < a)); + EXPECT_TRUE (CB(a < b)); + EXPECT_FALSE(CB(b < a)); + EXPECT_FALSE(CB(b < b)); + + EXPECT_FALSE(CB(a > a)); + EXPECT_FALSE(CB(a > b)); + EXPECT_TRUE (CB(b > a)); + EXPECT_FALSE(CB(b > b)); + + EXPECT_TRUE (CB(a <= a)); + EXPECT_TRUE (CB(a <= b)); + EXPECT_FALSE(CB(b <= a)); + EXPECT_TRUE (CB(b <= b)); + + EXPECT_TRUE (CB(a >= a)); + EXPECT_FALSE(CB(a >= b)); + EXPECT_TRUE (CB(b >= a)); + EXPECT_TRUE (CB(b >= b)); + + EXPECT_TRUE (CB(a == a)); + EXPECT_FALSE(CB(a == b)); + EXPECT_FALSE(CB(b == a)); + EXPECT_TRUE (CB(b == b)); + + EXPECT_FALSE(CB(a != a)); + EXPECT_TRUE (CB(a != b)); + EXPECT_TRUE (CB(b != a)); + EXPECT_FALSE(CB(b != b)); +} + +TEST(ip4addr, str) +{ + IP4Address a; + EXPECT_EQ("0.0.0.0"_s, STRNPRINTF(17, "%s"_fmt, a)); + EXPECT_EQ("127.0.0.1"_s, STRNPRINTF(17, "%s"_fmt, IP4_LOCALHOST)); + EXPECT_EQ("255.255.255.255"_s, STRNPRINTF(17, "%s"_fmt, IP4_BROADCAST)); +} + +TEST(ip4addr, extract) +{ + IP4Address a; + EXPECT_TRUE(extract("0.0.0.0"_s, &a)); + EXPECT_EQ("0.0.0.0"_s, STRNPRINTF(16, "%s"_fmt, a)); + EXPECT_TRUE(extract("127.0.0.1"_s, &a)); + EXPECT_EQ("127.0.0.1"_s, STRNPRINTF(16, "%s"_fmt, a)); + EXPECT_TRUE(extract("255.255.255.255"_s, &a)); + EXPECT_EQ("255.255.255.255"_s, STRNPRINTF(16, "%s"_fmt, a)); + EXPECT_TRUE(extract("1.2.3.4"_s, &a)); + EXPECT_EQ("1.2.3.4"_s, STRNPRINTF(16, "%s"_fmt, a)); + + EXPECT_FALSE(extract("1.2.3.4.5"_s, &a)); + EXPECT_FALSE(extract("1.2.3.4."_s, &a)); + EXPECT_FALSE(extract("1.2.3."_s, &a)); + EXPECT_FALSE(extract("1.2.3"_s, &a)); + EXPECT_FALSE(extract("1.2."_s, &a)); + EXPECT_FALSE(extract("1.2"_s, &a)); + EXPECT_FALSE(extract("1."_s, &a)); + EXPECT_FALSE(extract("1"_s, &a)); + EXPECT_FALSE(extract(""_s, &a)); +} + + +TEST(ip4mask, body) +{ + IP4Mask m; + EXPECT_EQ(IP4Address(), m.addr()); + EXPECT_EQ(IP4Address(), m.mask()); + m = IP4Mask(IP4_LOCALHOST, IP4_BROADCAST); + EXPECT_EQ(IP4_LOCALHOST, m.addr()); + EXPECT_EQ(IP4_BROADCAST, m.mask()); +} + +TEST(ip4mask, str) +{ + IP4Mask m; + EXPECT_EQ("0.0.0.0/0.0.0.0"_s, STRNPRINTF(33, "%s"_fmt, m)); + m = IP4Mask(IP4_LOCALHOST, IP4_BROADCAST); + EXPECT_EQ("127.0.0.1/255.255.255.255"_s, STRNPRINTF(33, "%s"_fmt, m)); +} + +TEST(ip4mask, extract) +{ + IP4Mask m; + EXPECT_FALSE(extract("9.8.7.6/33"_s, &m)); + EXPECT_FALSE(extract("9.8.7.6.5"_s, &m)); + EXPECT_FALSE(extract("9.8.7.6/"_s, &m)); + EXPECT_FALSE(extract("9.8.7"_s, &m)); + EXPECT_FALSE(extract("9.8"_s, &m)); + EXPECT_FALSE(extract("9"_s, &m)); + + EXPECT_TRUE(extract("127.0.0.1"_s, &m)); + EXPECT_EQ("127.0.0.1/255.255.255.255"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("127.0.0.1."_s, &m)); + EXPECT_EQ("127.0.0.1/255.255.255.255"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("127.0.0."_s, &m)); + EXPECT_EQ("127.0.0.0/255.255.255.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("127.0."_s, &m)); + EXPECT_EQ("127.0.0.0/255.255.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("127."_s, &m)); + EXPECT_EQ("127.0.0.0/255.0.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + + EXPECT_TRUE(extract("1.2.3.4/255.255.255.255"_s, &m)); + EXPECT_EQ("1.2.3.4/255.255.255.255"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("1.2.3.0/255.255.255.0"_s, &m)); + EXPECT_EQ("1.2.3.0/255.255.255.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("1.2.0.4/255.255.0.255"_s, &m)); + EXPECT_EQ("1.2.0.4/255.255.0.255"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("1.2.0.0/255.255.0.0"_s, &m)); + EXPECT_EQ("1.2.0.0/255.255.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("1.0.3.4/255.0.255.255"_s, &m)); + EXPECT_EQ("1.0.3.4/255.0.255.255"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("1.0.3.0/255.0.255.0"_s, &m)); + EXPECT_EQ("1.0.3.0/255.0.255.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("1.0.0.4/255.0.0.255"_s, &m)); + EXPECT_EQ("1.0.0.4/255.0.0.255"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("1.0.0.0/255.0.0.0"_s, &m)); + EXPECT_EQ("1.0.0.0/255.0.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.2.3.4/0.255.255.255"_s, &m)); + EXPECT_EQ("0.2.3.4/0.255.255.255"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.2.3.0/0.255.255.0"_s, &m)); + EXPECT_EQ("0.2.3.0/0.255.255.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.2.0.4/0.255.0.255"_s, &m)); + EXPECT_EQ("0.2.0.4/0.255.0.255"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.2.0.0/0.255.0.0"_s, &m)); + EXPECT_EQ("0.2.0.0/0.255.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.3.4/0.0.255.255"_s, &m)); + EXPECT_EQ("0.0.3.4/0.0.255.255"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.3.0/0.0.255.0"_s, &m)); + EXPECT_EQ("0.0.3.0/0.0.255.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.4/0.0.0.255"_s, &m)); + EXPECT_EQ("0.0.0.4/0.0.0.255"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/0.0.0.0"_s, &m)); + EXPECT_EQ("0.0.0.0/0.0.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + + // please don't do this + EXPECT_TRUE(extract("120.248.200.217/89.57.126.5"_s, &m)); + EXPECT_EQ("88.56.72.1/89.57.126.5"_s, STRNPRINTF(32, "%s"_fmt, m)); + + EXPECT_TRUE(extract("0.0.0.0/32"_s, &m)); + EXPECT_EQ("0.0.0.0/255.255.255.255"_s, STRNPRINTF(32, "%s"_fmt, m)); + + EXPECT_TRUE(extract("0.0.0.0/31"_s, &m)); + EXPECT_EQ("0.0.0.0/255.255.255.254"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/30"_s, &m)); + EXPECT_EQ("0.0.0.0/255.255.255.252"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/29"_s, &m)); + EXPECT_EQ("0.0.0.0/255.255.255.248"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/28"_s, &m)); + EXPECT_EQ("0.0.0.0/255.255.255.240"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/27"_s, &m)); + EXPECT_EQ("0.0.0.0/255.255.255.224"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/26"_s, &m)); + EXPECT_EQ("0.0.0.0/255.255.255.192"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/25"_s, &m)); + EXPECT_EQ("0.0.0.0/255.255.255.128"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/24"_s, &m)); + EXPECT_EQ("0.0.0.0/255.255.255.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + + EXPECT_TRUE(extract("0.0.0.0/23"_s, &m)); + EXPECT_EQ("0.0.0.0/255.255.254.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/22"_s, &m)); + EXPECT_EQ("0.0.0.0/255.255.252.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/21"_s, &m)); + EXPECT_EQ("0.0.0.0/255.255.248.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/20"_s, &m)); + EXPECT_EQ("0.0.0.0/255.255.240.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/19"_s, &m)); + EXPECT_EQ("0.0.0.0/255.255.224.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/18"_s, &m)); + EXPECT_EQ("0.0.0.0/255.255.192.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/17"_s, &m)); + EXPECT_EQ("0.0.0.0/255.255.128.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/16"_s, &m)); + EXPECT_EQ("0.0.0.0/255.255.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + + EXPECT_TRUE(extract("0.0.0.0/15"_s, &m)); + EXPECT_EQ("0.0.0.0/255.254.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/14"_s, &m)); + EXPECT_EQ("0.0.0.0/255.252.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/13"_s, &m)); + EXPECT_EQ("0.0.0.0/255.248.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/12"_s, &m)); + EXPECT_EQ("0.0.0.0/255.240.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/11"_s, &m)); + EXPECT_EQ("0.0.0.0/255.224.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/10"_s, &m)); + EXPECT_EQ("0.0.0.0/255.192.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/9"_s, &m)); + EXPECT_EQ("0.0.0.0/255.128.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/8"_s, &m)); + EXPECT_EQ("0.0.0.0/255.0.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + + EXPECT_TRUE(extract("0.0.0.0/7"_s, &m)); + EXPECT_EQ("0.0.0.0/254.0.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/6"_s, &m)); + EXPECT_EQ("0.0.0.0/252.0.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/5"_s, &m)); + EXPECT_EQ("0.0.0.0/248.0.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/4"_s, &m)); + EXPECT_EQ("0.0.0.0/240.0.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/3"_s, &m)); + EXPECT_EQ("0.0.0.0/224.0.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/2"_s, &m)); + EXPECT_EQ("0.0.0.0/192.0.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/1"_s, &m)); + EXPECT_EQ("0.0.0.0/128.0.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); + EXPECT_TRUE(extract("0.0.0.0/0"_s, &m)); + EXPECT_EQ("0.0.0.0/0.0.0.0"_s, STRNPRINTF(32, "%s"_fmt, m)); +} + +TEST(ip4mask, cover) +{ + IP4Address a; + IP4Address b = IP4_BROADCAST; + IP4Address l = IP4_LOCALHOST; + IP4Address h({127, 255, 255, 255}); + IP4Address p24l({10, 0, 0, 0}); + IP4Address p24h({10, 255, 255, 255}); + IP4Address p20l({172, 16, 0, 0}); + IP4Address p20h({172, 31, 255, 255}); + IP4Address p16l({192, 168, 0, 0}); + IP4Address p16h({192, 168, 255, 255}); + IP4Mask m; + EXPECT_TRUE(m.covers(a)); + EXPECT_TRUE(m.covers(b)); + EXPECT_TRUE(m.covers(l)); + EXPECT_TRUE(m.covers(h)); + EXPECT_TRUE(m.covers(p24l)); + EXPECT_TRUE(m.covers(p24h)); + EXPECT_TRUE(m.covers(p20l)); + EXPECT_TRUE(m.covers(p20h)); + EXPECT_TRUE(m.covers(p16l)); + EXPECT_TRUE(m.covers(p16h)); + m = IP4Mask(l, a); + EXPECT_TRUE(m.covers(a)); + EXPECT_TRUE(m.covers(b)); + EXPECT_TRUE(m.covers(l)); + EXPECT_TRUE(m.covers(h)); + EXPECT_TRUE(m.covers(p24l)); + EXPECT_TRUE(m.covers(p24h)); + EXPECT_TRUE(m.covers(p20l)); + EXPECT_TRUE(m.covers(p20h)); + EXPECT_TRUE(m.covers(p16l)); + EXPECT_TRUE(m.covers(p16h)); + m = IP4Mask(l, b); + EXPECT_FALSE(m.covers(a)); + EXPECT_FALSE(m.covers(b)); + EXPECT_TRUE(m.covers(l)); + EXPECT_FALSE(m.covers(h)); + EXPECT_FALSE(m.covers(p24l)); + EXPECT_FALSE(m.covers(p24h)); + EXPECT_FALSE(m.covers(p20l)); + EXPECT_FALSE(m.covers(p20h)); + EXPECT_FALSE(m.covers(p16l)); + EXPECT_FALSE(m.covers(p16h)); + + // but the really useful ones are with partial masks + m = IP4Mask(IP4Address({10, 0, 0, 0}), IP4Address({255, 0, 0, 0})); + EXPECT_FALSE(m.covers(a)); + EXPECT_FALSE(m.covers(b)); + EXPECT_FALSE(m.covers(l)); + EXPECT_FALSE(m.covers(h)); + EXPECT_TRUE(m.covers(p24l)); + EXPECT_TRUE(m.covers(p24h)); + EXPECT_FALSE(m.covers(p20l)); + EXPECT_FALSE(m.covers(p20h)); + EXPECT_FALSE(m.covers(p16l)); + EXPECT_FALSE(m.covers(p16h)); + EXPECT_FALSE(m.covers(IP4Address({9, 255, 255, 255}))); + EXPECT_FALSE(m.covers(IP4Address({11, 0, 0, 0}))); + m = IP4Mask(IP4Address({127, 0, 0, 0}), IP4Address({255, 0, 0, 0})); + EXPECT_FALSE(m.covers(a)); + EXPECT_FALSE(m.covers(b)); + EXPECT_TRUE(m.covers(l)); + EXPECT_TRUE(m.covers(h)); + EXPECT_FALSE(m.covers(p24l)); + EXPECT_FALSE(m.covers(p24h)); + EXPECT_FALSE(m.covers(p20l)); + EXPECT_FALSE(m.covers(p20h)); + EXPECT_FALSE(m.covers(p16l)); + EXPECT_FALSE(m.covers(p16h)); + EXPECT_FALSE(m.covers(IP4Address({126, 255, 255, 255}))); + EXPECT_FALSE(m.covers(IP4Address({128, 0, 0, 0}))); + m = IP4Mask(IP4Address({172, 16, 0, 0}), IP4Address({255, 240, 0, 0})); + EXPECT_FALSE(m.covers(a)); + EXPECT_FALSE(m.covers(b)); + EXPECT_FALSE(m.covers(l)); + EXPECT_FALSE(m.covers(h)); + EXPECT_FALSE(m.covers(p24l)); + EXPECT_FALSE(m.covers(p24h)); + EXPECT_TRUE(m.covers(p20l)); + EXPECT_TRUE(m.covers(p20h)); + EXPECT_FALSE(m.covers(p16l)); + EXPECT_FALSE(m.covers(p16h)); + EXPECT_FALSE(m.covers(IP4Address({172, 15, 255, 255}))); + EXPECT_FALSE(m.covers(IP4Address({172, 32, 0, 0}))); + m = IP4Mask(IP4Address({192, 168, 0, 0}), IP4Address({255, 255, 0, 0})); + EXPECT_FALSE(m.covers(a)); + EXPECT_FALSE(m.covers(b)); + EXPECT_FALSE(m.covers(l)); + EXPECT_FALSE(m.covers(h)); + EXPECT_FALSE(m.covers(p24l)); + EXPECT_FALSE(m.covers(p24h)); + EXPECT_FALSE(m.covers(p20l)); + EXPECT_FALSE(m.covers(p20h)); + EXPECT_TRUE(m.covers(p16l)); + EXPECT_TRUE(m.covers(p16h)); + EXPECT_FALSE(m.covers(IP4Address({192, 167, 255, 255}))); + EXPECT_FALSE(m.covers(IP4Address({192, 169, 0, 0}))); + + // OTOH this is crazy + EXPECT_TRUE(extract("120.248.200.217/89.57.126.5"_s, &m)); + EXPECT_TRUE(m.covers(IP4Address({120, 248, 200, 217}))); + EXPECT_TRUE(m.covers(IP4Address({88, 56, 72, 1}))); + EXPECT_FALSE(m.covers(IP4Address({88, 56, 72, 0}))); + EXPECT_FALSE(m.covers(IP4Address({88, 56, 72, 255}))); +} +} // namespace tmwa diff --git a/src/net/packets.cpp b/src/net/packets.cpp new file mode 100644 index 0000000..3cba856 --- /dev/null +++ b/src/net/packets.cpp @@ -0,0 +1,106 @@ +#include "packets.hpp" +// packets.cpp - palatable socket buffer accessors +// +// Copyright © 2014 Ben Longbons <b.r.longbons@gmail.com> +// +// This file is part of The Mana World (Athena server) +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +#include "../io/cxxstdio.hpp" +#include "../io/write.hpp" + +#include "../poison.hpp" + + +namespace tmwa +{ +size_t packet_avail(Session *s) +{ + return s->rdata_size - s->rdata_pos; +} + +bool packet_fetch(Session *s, size_t offset, Byte *data, size_t sz) +{ + if (packet_avail(s) < offset + sz) + return false; + const Byte *start = reinterpret_cast<const Byte *>(&s->rdata[s->rdata_pos + offset]); + const Byte *end = start + sz; + std::copy(start, end, data); + return true; +} +void packet_discard(Session *s, size_t sz) +{ + s->rdata_pos += sz; + + assert (s->rdata_size >= s->rdata_pos); +} +bool packet_send(Session *s, const Byte *data, size_t sz) +{ + if (s->wdata_size + sz > s->max_wdata) + { + realloc_fifo(s, s->max_rdata, s->max_wdata << 1); + PRINTF("socket: %d wdata expanded to %zu bytes.\n"_fmt, s, s->max_wdata); + } + if (!s->max_wdata || !s->wdata) + { + return false; + } + s->wdata_size += sz; + + Byte *end = reinterpret_cast<Byte *>(&s->wdata[s->wdata_size + 0]); + Byte *start = end - sz; + std::copy(data, data + sz, start); + return true; +} + +void packet_dump(io::WriteFile& logfp, Session *s) +{ + FPRINTF(logfp, + "---- 00-01-02-03-04-05-06-07 08-09-0A-0B-0C-0D-0E-0F\n"_fmt); + char tmpstr[16 + 1] {}; + int i; + for (i = 0; i < packet_avail(s); i++) + { + if ((i & 15) == 0) + FPRINTF(logfp, "%04X "_fmt, i); + Byte rfifob_ib; + packet_fetch(s, i, &rfifob_ib, 1); + uint8_t rfifob_i = rfifob_ib.value; + FPRINTF(logfp, "%02x "_fmt, rfifob_i); + if (rfifob_i > 0x1f) + tmpstr[i % 16] = rfifob_i; + else + tmpstr[i % 16] = '.'; + if ((i - 7) % 16 == 0) // -8 + 1 + FPRINTF(logfp, " "_fmt); + else if ((i + 1) % 16 == 0) + { + FPRINTF(logfp, " %s\n"_fmt, tmpstr); + std::fill(tmpstr + 0, tmpstr + 17, '\0'); + } + } + if (i % 16 != 0) + { + for (int j = i; j % 16 != 0; j++) + { + FPRINTF(logfp, " "_fmt); + if ((j - 7) % 16 == 0) // -8 + 1 + FPRINTF(logfp, " "_fmt); + } + FPRINTF(logfp, " %s\n"_fmt, tmpstr); + } + FPRINTF(logfp, "\n"_fmt); +} +} // namespace tmwa diff --git a/src/net/packets.hpp b/src/net/packets.hpp new file mode 100644 index 0000000..5cc377c --- /dev/null +++ b/src/net/packets.hpp @@ -0,0 +1,585 @@ +#pragma once +// packets.hpp - palatable socket buffer accessors +// +// Copyright © 2014 Ben Longbons <b.r.longbons@gmail.com> +// +// This file is part of The Mana World (Athena server) +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +#include "fwd.hpp" + +#include <vector> + +#include "../compat/cast.hpp" + +#include "../ints/little.hpp" + +#include "../io/fwd.hpp" + +// TODO ordering violation, should invert +#include "../proto2/fwd.hpp" + +#include "socket.hpp" + + +namespace tmwa +{ +struct Buffer +{ + std::vector<Byte> bytes; +}; + +enum class RecvResult +{ + Incomplete, + Complete, + Error, +}; + +enum class SendResult +{ + Success, + Fail, +}; + + +size_t packet_avail(Session *s); +void packet_dump(io::WriteFile& out, Session *s); + +bool packet_fetch(Session *s, size_t offset, Byte *data, size_t sz); +void packet_discard(Session *s, size_t sz); +bool packet_send(Session *s, const Byte *data, size_t sz); + +inline +bool packet_peek_id(Session *s, uint16_t *packet_id) +{ + Little16 id; + bool okay = packet_fetch(s, 0, reinterpret_cast<Byte *>(&id), 2); + if (okay) + { + if (!network_to_native(packet_id, id)) + { + s->set_eof(); + return false; + } + } + return okay; +} + +inline +void send_buffer(Session *s, const Buffer& buffer) +{ + bool ok = !buffer.bytes.empty() && packet_send(s, buffer.bytes.data(), buffer.bytes.size()); + if (!ok) + s->set_eof(); +} + +template<uint16_t id> +__attribute__((warn_unused_result)) +RecvResult net_recv_fpacket(Session *s, NetPacket_Fixed<id>& fixed) +{ + bool ok = packet_fetch(s, 0, reinterpret_cast<Byte *>(&fixed), sizeof(NetPacket_Fixed<id>)); + if (ok) + { + packet_discard(s, sizeof(NetPacket_Fixed<id>)); + return RecvResult::Complete; + } + return RecvResult::Incomplete; +} + +template<uint16_t id> +__attribute__((warn_unused_result)) +RecvResult net_recv_ppacket(Session *s, NetPacket_Payload<id>& payload) +{ + bool ok = packet_fetch(s, 0, reinterpret_cast<Byte *>(&payload), sizeof(NetPacket_Payload<id>)); + if (ok) + { + packet_discard(s, sizeof(NetPacket_Payload<id>)); + return RecvResult::Complete; + } + return RecvResult::Incomplete; +} + +template<uint16_t id> +__attribute__((warn_unused_result)) +RecvResult net_recv_vpacket(Session *s, NetPacket_Head<id>& head, std::vector<NetPacket_Repeat<id>>& repeat) +{ + bool ok = packet_fetch(s, 0, reinterpret_cast<Byte *>(&head), sizeof(NetPacket_Head<id>)); + if (ok) + { + Packet_Head<id> nat; + if (!network_to_native(&nat, head)) + return RecvResult::Error; + if (packet_avail(s) < nat.magic_packet_length) + return RecvResult::Incomplete; + if (nat.magic_packet_length < sizeof(NetPacket_Head<id>)) + return RecvResult::Error; + size_t bytes_repeat = nat.magic_packet_length - sizeof(NetPacket_Head<id>); + if (bytes_repeat % sizeof(NetPacket_Repeat<id>)) + return RecvResult::Error; + repeat.resize(bytes_repeat / sizeof(NetPacket_Repeat<id>)); + if (packet_fetch(s, sizeof(NetPacket_Head<id>), reinterpret_cast<Byte *>(repeat.data()), bytes_repeat)) + { + packet_discard(s, nat.magic_packet_length); + return RecvResult::Complete; + } + return RecvResult::Incomplete; + } + return RecvResult::Incomplete; +} + +template<uint16_t id> +__attribute__((warn_unused_result)) +RecvResult net_recv_opacket(Session *s, NetPacket_Head<id>& head, bool *has_opt, NetPacket_Option<id>& opt) +{ + bool ok = packet_fetch(s, 0, reinterpret_cast<Byte *>(&head), sizeof(NetPacket_Head<id>)); + if (ok) + { + Packet_Head<id> nat; + if (!network_to_native(&nat, head)) + return RecvResult::Error; + if (packet_avail(s) < nat.magic_packet_length) + return RecvResult::Incomplete; + if (nat.magic_packet_length < sizeof(NetPacket_Head<id>)) + return RecvResult::Error; + size_t bytes_repeat = nat.magic_packet_length - sizeof(NetPacket_Head<id>); + if (bytes_repeat % sizeof(NetPacket_Option<id>)) + return RecvResult::Error; + size_t has_opt_pls = bytes_repeat / sizeof(NetPacket_Option<id>); + if (has_opt_pls > 1) + return RecvResult::Error; + *has_opt = has_opt_pls; + if (!*has_opt || packet_fetch(s, sizeof(NetPacket_Head<id>), reinterpret_cast<Byte *>(&opt), sizeof(NetPacket_Option<id>))) + { + packet_discard(s, nat.magic_packet_length); + return RecvResult::Complete; + } + return RecvResult::Incomplete; + } + return RecvResult::Incomplete; +} + + +template<uint16_t id, uint16_t size> +Buffer create_fpacket(const Packet_Fixed<id>& fixed) +{ + static_assert(id == Packet_Fixed<id>::PACKET_ID, "Packet_Fixed<id>::PACKET_ID"); + static_assert(size == sizeof(NetPacket_Fixed<id>), "sizeof(NetPacket_Fixed<id>)"); + + Buffer buf; + buf.bytes.resize(sizeof(NetPacket_Fixed<id>)); + auto& net_fixed = reinterpret_cast<NetPacket_Fixed<id>&>( + *(buf.bytes.begin() + 0)); + if (!native_to_network(&net_fixed, fixed)) + { + return Buffer(); + } + return buf; +} + +template<uint16_t id> +Buffer create_ppacket(Packet_Payload<id>& payload) +{ + static_assert(id == Packet_Payload<id>::PACKET_ID, "Packet_Payload<id>::PACKET_ID"); + + if (id != 0x8000) + payload.magic_packet_length = sizeof(NetPacket_Payload<id>); + + Buffer buf; + buf.bytes.resize(sizeof(NetPacket_Payload<id>)); + auto& net_payload = reinterpret_cast<NetPacket_Payload<id>&>( + *(buf.bytes.begin() + 0)); + if (!native_to_network(&net_payload, payload)) + { + return Buffer(); + } + return buf; +} + +template<uint16_t id, uint16_t headsize, uint16_t repeatsize> +Buffer create_vpacket(Packet_Head<id>& head, const std::vector<Packet_Repeat<id>>& repeat) +{ + static_assert(id == Packet_Head<id>::PACKET_ID, "Packet_Head<id>::PACKET_ID"); + static_assert(headsize == sizeof(NetPacket_Head<id>), "sizeof(NetPacket_Head<id>)"); + static_assert(id == Packet_Repeat<id>::PACKET_ID, "Packet_Repeat<id>::PACKET_ID"); + static_assert(repeatsize == sizeof(NetPacket_Repeat<id>), "sizeof(NetPacket_Repeat<id>)"); + + // since these are already allocated, can't overflow address space + size_t total_size = sizeof(NetPacket_Head<id>) + repeat.size() * sizeof(NetPacket_Repeat<id>); + // truncates + head.magic_packet_length = total_size; + if (head.magic_packet_length != total_size) + { + return Buffer(); + } + + Buffer buf; + buf.bytes.resize(total_size); + auto& net_head = reinterpret_cast<NetPacket_Head<id>&>( + *(buf.bytes.begin() + 0)); + if (!native_to_network(&net_head, head)) + { + return Buffer(); + } + for (size_t i = 0; i < repeat.size(); ++i) + { + auto& net_repeat_i = reinterpret_cast<NetPacket_Repeat<id>&>( + *(buf.bytes.begin() + + sizeof(NetPacket_Head<id>) + + i * sizeof(NetPacket_Repeat<id>))); + if (!native_to_network(&net_repeat_i, repeat[i])) + { + return Buffer(); + } + } + return buf; +} + +template<uint16_t id, uint16_t headsize, uint16_t optsize> +Buffer create_opacket(Packet_Head<id>& head, bool has_opt, const Packet_Option<id>& opt) +{ + static_assert(id == Packet_Head<id>::PACKET_ID, "Packet_Head<id>::PACKET_ID"); + static_assert(headsize == sizeof(NetPacket_Head<id>), "sizeof(NetPacket_Head<id>)"); + static_assert(id == Packet_Option<id>::PACKET_ID, "Packet_Option<id>::PACKET_ID"); + static_assert(optsize == sizeof(NetPacket_Option<id>), "sizeof(NetPacket_Option<id>)"); + + // since these are already allocated, can't overflow address space + size_t total_size = sizeof(NetPacket_Head<id>) + has_opt * sizeof(NetPacket_Option<id>); + // truncates + head.magic_packet_length = total_size; + if (head.magic_packet_length != total_size) + { + return Buffer(); + } + + Buffer buf; + buf.bytes.resize(total_size); + + auto& net_head = reinterpret_cast<NetPacket_Head<id>&>( + *(buf.bytes.begin() + 0)); + if (!native_to_network(&net_head, head)) + { + return Buffer(); + } + if (has_opt) + { + auto& net_opt = reinterpret_cast<NetPacket_Option<id>&>( + *(buf.bytes.begin() + + sizeof(NetPacket_Head<id>))); + if (!native_to_network(&net_opt, opt)) + { + return Buffer(); + } + } + + return buf; +} + +template<uint16_t id, uint16_t size> +void send_fpacket(Session *s, const Packet_Fixed<id>& fixed) +{ + Buffer pkt = create_fpacket<id, size>(fixed); + send_buffer(s, pkt); +} + +template<uint16_t id> +void send_ppacket(Session *s, Packet_Payload<id>& payload) +{ + Buffer pkt = create_ppacket<id>(payload); + send_buffer(s, pkt); +} + +template<uint16_t id, uint16_t headsize, uint16_t repeatsize> +void send_vpacket(Session *s, Packet_Head<id>& head, const std::vector<Packet_Repeat<id>>& repeat) +{ + Buffer pkt = create_vpacket<id, headsize, repeatsize>(head, repeat); + send_buffer(s, pkt); +} + +template<uint16_t id, uint16_t headsize, uint16_t optsize> +void send_opacket(Session *s, Packet_Head<id>& head, bool has_opt, const Packet_Option<id>& opt) +{ + Buffer pkt = create_opacket<id, headsize, optsize>(head, has_opt, opt); + send_buffer(s, pkt); +} + +template<uint16_t id, uint16_t size> +__attribute__((warn_unused_result)) +RecvResult recv_fpacket(Session *s, Packet_Fixed<id>& fixed) +{ + static_assert(id == Packet_Fixed<id>::PACKET_ID, "Packet_Fixed<id>::PACKET_ID"); + static_assert(size == sizeof(NetPacket_Fixed<id>), "NetPacket_Fixed<id>"); + + NetPacket_Fixed<id> net_fixed; + RecvResult rv = net_recv_fpacket(s, net_fixed); + if (rv == RecvResult::Complete) + { + if (!network_to_native(&fixed, net_fixed)) + return RecvResult::Error; + assert (fixed.magic_packet_id == Packet_Fixed<id>::PACKET_ID); + } + return rv; +} + +template<uint16_t id> +__attribute__((warn_unused_result)) +RecvResult recv_ppacket(Session *s, Packet_Payload<id>& payload) +{ + static_assert(id == Packet_Payload<id>::PACKET_ID, "Packet_Payload<id>::PACKET_ID"); + + NetPacket_Payload<id> net_payload; + RecvResult rv = net_recv_ppacket(s, net_payload); + if (rv == RecvResult::Complete) + { + if (!network_to_native(&payload, net_payload)) + return RecvResult::Error; + assert (payload.magic_packet_id == Packet_Payload<id>::PACKET_ID); + if (id == 0x8000) + { + // 0x8000 is special + if (packet_avail(s) < payload.magic_packet_length) + return RecvResult::Incomplete; + payload.magic_packet_length = 4; + return RecvResult::Complete; + } + if (payload.magic_packet_length != sizeof(net_payload)) + return RecvResult::Error; + } + return rv; +} + +template<uint16_t id, uint16_t headsize, uint16_t repeatsize> +__attribute__((warn_unused_result)) +RecvResult recv_vpacket(Session *s, Packet_Head<id>& head, std::vector<Packet_Repeat<id>>& repeat) +{ + static_assert(id == Packet_Head<id>::PACKET_ID, "Packet_Head<id>::PACKET_ID"); + static_assert(headsize == sizeof(NetPacket_Head<id>), "NetPacket_Head<id>"); + static_assert(id == Packet_Repeat<id>::PACKET_ID, "Packet_Repeat<id>::PACKET_ID"); + static_assert(repeatsize == sizeof(NetPacket_Repeat<id>), "NetPacket_Repeat<id>"); + + NetPacket_Head<id> net_head; + std::vector<NetPacket_Repeat<id>> net_repeat; + RecvResult rv = net_recv_vpacket(s, net_head, net_repeat); + if (rv == RecvResult::Complete) + { + if (!network_to_native(&head, net_head)) + return RecvResult::Error; + assert (head.magic_packet_id == Packet_Head<id>::PACKET_ID); + + repeat.resize(net_repeat.size()); + for (size_t i = 0; i < net_repeat.size(); ++i) + { + if (!network_to_native(&repeat[i], net_repeat[i])) + return RecvResult::Error; + } + } + return rv; +} + +template<uint16_t id, uint16_t headsize, uint16_t optsize> +__attribute__((warn_unused_result)) +RecvResult recv_opacket(Session *s, Packet_Head<id>& head, bool *has_opt, Packet_Option<id>& opt) +{ + static_assert(id == Packet_Head<id>::PACKET_ID, "Packet_Head<id>::PACKET_ID"); + static_assert(headsize == sizeof(NetPacket_Head<id>), "NetPacket_Head<id>"); + static_assert(id == Packet_Option<id>::PACKET_ID, "Packet_Option<id>::PACKET_ID"); + static_assert(optsize == sizeof(NetPacket_Option<id>), "NetPacket_Option<id>"); + + NetPacket_Head<id> net_head; + NetPacket_Option<id> net_opt; + RecvResult rv = net_recv_opacket(s, net_head, has_opt, net_opt); + if (rv == RecvResult::Complete) + { + if (!network_to_native(&head, net_head)) + return RecvResult::Error; + assert (head.magic_packet_id == Packet_Head<id>::PACKET_ID); + + if (*has_opt) + { + if (!network_to_native(&opt, net_opt)) + return RecvResult::Error; + } + } + return rv; +} + + +// convenience for trailing strings + +template<uint16_t id, uint16_t headsize, uint16_t repeatsize> +Buffer create_vpacket(Packet_Head<id>& head, const XString& repeat) +{ + static_assert(id == Packet_Head<id>::PACKET_ID, "Packet_Head<id>::PACKET_ID"); + static_assert(headsize == sizeof(NetPacket_Head<id>), "NetPacket_Head<id>"); + static_assert(id == Packet_Repeat<id>::PACKET_ID, "Packet_Repeat<id>::PACKET_ID"); + static_assert(repeatsize == sizeof(NetPacket_Repeat<id>), "NetPacket_Repeat<id>"); + static_assert(repeatsize == 1, "repeatsize"); + + // since it's already allocated, it can't overflow address space + size_t total_length = sizeof(NetPacket_Head<id>) + (repeat.size() + 1) * sizeof(NetPacket_Repeat<id>); + head.magic_packet_length = total_length; + if (head.magic_packet_length != total_length) + { + return Buffer(); + } + + Buffer buf; + buf.bytes.resize(total_length); + auto& net_head = reinterpret_cast<NetPacket_Head<id>&>( + *(buf.bytes.begin() + 0)); + std::vector<NetPacket_Repeat<id>> net_repeat(repeat.size() + 1); + if (!native_to_network(&net_head, head)) + { + return Buffer(); + } + for (size_t i = 0; i < repeat.size(); ++i) + { + auto& net_repeat_i = reinterpret_cast<NetPacket_Repeat<id>&>( + *(buf.bytes.begin() + + sizeof(NetPacket_Head<id>) + + i)); + net_repeat_i.c = Byte{static_cast<uint8_t>(repeat[i])}; + } + auto& net_repeat_repeat_size = reinterpret_cast<NetPacket_Repeat<id>&>( + *(buf.bytes.begin() + + sizeof(NetPacket_Head<id>) + + repeat.size())); + net_repeat_repeat_size.c = Byte{static_cast<uint8_t>('\0')}; + return buf; +} + +template<uint16_t id, uint16_t headsize, uint16_t repeatsize> +void send_vpacket(Session *s, Packet_Head<id>& head, const XString& repeat) +{ + Buffer pkt = create_vpacket<id, headsize, repeatsize>(head, repeat); + send_buffer(s, pkt); +} + +template<uint16_t id, uint16_t headsize, uint16_t repeatsize> +__attribute__((warn_unused_result)) +RecvResult recv_vpacket(Session *s, Packet_Head<id>& head, AString& repeat) +{ + static_assert(id == Packet_Head<id>::PACKET_ID, "Packet_Head<id>::PACKET_ID"); + static_assert(headsize == sizeof(NetPacket_Head<id>), "NetPacket_Head<id>"); + static_assert(id == Packet_Repeat<id>::PACKET_ID, "Packet_Repeat<id>::PACKET_ID"); + static_assert(repeatsize == sizeof(NetPacket_Repeat<id>), "NetPacket_Repeat<id>"); + static_assert(repeatsize == 1, "repeatsize"); + + NetPacket_Head<id> net_head; + std::vector<NetPacket_Repeat<id>> net_repeat; + RecvResult rv = net_recv_vpacket(s, net_head, net_repeat); + assert (head.magic_packet_id == Packet_Head<id>::PACKET_ID); + if (rv == RecvResult::Complete) + { + if (!network_to_native(&head, net_head)) + return RecvResult::Error; + // reinterpret_cast is needed to correctly handle an empty vector + const char *begin = sign_cast<const char *>(net_repeat.data()); + const char *end = begin + net_repeat.size(); + end = std::find(begin, end, '\0'); + repeat = XString(begin, end, nullptr); + } + return rv; +} + + +// if there is nothing in the head but the id and length, use the below + +template<uint16_t id, uint16_t headsize, uint16_t repeatsize> +Buffer create_packet_repeatonly(const std::vector<Packet_Repeat<id>>& v) +{ + static_assert(id == Packet_Head<id>::PACKET_ID, "Packet_Head<id>::PACKET_ID"); + static_assert(headsize == sizeof(NetPacket_Head<id>), "repeat headsize"); + static_assert(headsize == 4, "repeat headsize"); + static_assert(id == Packet_Repeat<id>::PACKET_ID, "Packet_Repeat<id>::PACKET_ID"); + static_assert(repeatsize == sizeof(NetPacket_Repeat<id>), "sizeof(NetPacket_Repeat<id>)"); + + Packet_Head<id> head; + return create_vpacket<id, 4, repeatsize>(head, v); +} + +template<uint16_t id, uint16_t headsize, uint16_t repeatsize> +void send_packet_repeatonly(Session *s, const std::vector<Packet_Repeat<id>>& v) +{ + static_assert(id == Packet_Head<id>::PACKET_ID, "Packet_Head<id>::PACKET_ID"); + static_assert(headsize == sizeof(NetPacket_Head<id>), "repeat headsize"); + static_assert(headsize == 4, "repeat headsize"); + static_assert(id == Packet_Repeat<id>::PACKET_ID, "Packet_Repeat<id>::PACKET_ID"); + static_assert(repeatsize == sizeof(NetPacket_Repeat<id>), "sizeof(NetPacket_Repeat<id>)"); + + Packet_Head<id> head; + send_vpacket<id, 4, repeatsize>(s, head, v); +} + +template<uint16_t id, uint16_t headsize, uint16_t repeatsize> +__attribute__((warn_unused_result)) +RecvResult recv_packet_repeatonly(Session *s, std::vector<Packet_Repeat<id>>& v) +{ + static_assert(id == Packet_Head<id>::PACKET_ID, "Packet_Head<id>::PACKET_ID"); + static_assert(headsize == sizeof(NetPacket_Head<id>), "repeat headsize"); + static_assert(headsize == 4, "repeat headsize"); + static_assert(id == Packet_Repeat<id>::PACKET_ID, "Packet_Repeat<id>::PACKET_ID"); + static_assert(repeatsize == sizeof(NetPacket_Repeat<id>), "sizeof(NetPacket_Repeat<id>)"); + + Packet_Head<id> head; + return recv_vpacket<id, 4, repeatsize>(s, head, v); +} + + +// and the combination of both of the above + +template<uint16_t id, uint16_t headsize, uint16_t repeatsize> +Buffer create_packet_repeatonly(const XString& repeat) +{ + static_assert(id == Packet_Head<id>::PACKET_ID, "Packet_Head<id>::PACKET_ID"); + static_assert(headsize == sizeof(NetPacket_Head<id>), "repeat headsize"); + static_assert(headsize == 4, "repeat headsize"); + static_assert(id == Packet_Repeat<id>::PACKET_ID, "Packet_Repeat<id>::PACKET_ID"); + static_assert(repeatsize == sizeof(NetPacket_Repeat<id>), "sizeof(NetPacket_Repeat<id>)"); + static_assert(repeatsize == 1, "repeatsize"); + + Packet_Head<id> head; + return create_vpacket<id, 4, repeatsize>(head, repeat); +} + +template<uint16_t id, uint16_t headsize, uint16_t repeatsize> +void send_packet_repeatonly(Session *s, const XString& repeat) +{ + static_assert(id == Packet_Head<id>::PACKET_ID, "Packet_Head<id>::PACKET_ID"); + static_assert(headsize == sizeof(NetPacket_Head<id>), "repeat headsize"); + static_assert(headsize == 4, "repeat headsize"); + static_assert(id == Packet_Repeat<id>::PACKET_ID, "Packet_Repeat<id>::PACKET_ID"); + static_assert(repeatsize == sizeof(NetPacket_Repeat<id>), "sizeof(NetPacket_Repeat<id>)"); + static_assert(repeatsize == 1, "repeatsize"); + + Packet_Head<id> head; + send_vpacket<id, 4, repeatsize>(s, head, repeat); +} + +template<uint16_t id, uint16_t headsize, uint16_t repeatsize> +__attribute__((warn_unused_result)) +RecvResult recv_packet_repeatonly(Session *s, AString& repeat) +{ + static_assert(id == Packet_Head<id>::PACKET_ID, "Packet_Head<id>::PACKET_ID"); + static_assert(headsize == sizeof(NetPacket_Head<id>), "repeat headsize"); + static_assert(headsize == 4, "repeat headsize"); + static_assert(id == Packet_Repeat<id>::PACKET_ID, "Packet_Repeat<id>::PACKET_ID"); + static_assert(repeatsize == sizeof(NetPacket_Repeat<id>), "sizeof(NetPacket_Repeat<id>)"); + static_assert(repeatsize == 1, "repeatsize"); + + Packet_Head<id> head; + return recv_vpacket<id, 4, repeatsize>(s, head, repeat); +} +} // namespace tmwa diff --git a/src/net/socket.cpp b/src/net/socket.cpp new file mode 100644 index 0000000..a01cd81 --- /dev/null +++ b/src/net/socket.cpp @@ -0,0 +1,487 @@ +#include "socket.hpp" +// socket.cpp - Network event system. +// +// Copyright © ????-2004 Athena Dev Teams +// Copyright © 2004-2011 The Mana World Development Team +// Copyright © 2011-2014 Ben Longbons <b.r.longbons@gmail.com> +// Copyright © 2013 MadCamel +// +// This file is part of The Mana World (Athena server) +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +#include <netinet/tcp.h> +#include <sys/socket.h> + +#include <fcntl.h> + +#include <cstdlib> + +#include <array> + +#include "../compat/memory.hpp" + +#include "../io/cxxstdio.hpp" + +// TODO get rid of ordering violations +#include "../mmo/utils.hpp" +#include "../mmo/core.hpp" + +#include "timer.hpp" + +#include "../poison.hpp" + + +namespace tmwa +{ +static +io::FD_Set readfds; +static +int fd_max; + +static +const uint32_t RFIFO_SIZE = 65536; +static +const uint32_t WFIFO_SIZE = 65536; + +DIAG_PUSH(); +DIAG_I(old_style_cast); +static +std::array<std::unique_ptr<Session>, FD_SETSIZE> session; +DIAG_POP(); + +Session::Session(SessionIO io, SessionParsers p) +: created() +, connected() +, eof() +, timed_close() +, rdata(), wdata() +, max_rdata(), max_wdata() +, rdata_size(), wdata_size() +, rdata_pos() +, client_ip() +, func_recv() +, func_send() +, func_parse() +, func_delete() +, for_inferior() +, session_data() +, fd() +{ + set_io(io); + set_parsers(p); +} +void Session::set_io(SessionIO io) +{ + func_send = io.func_send; + func_recv = io.func_recv; +} +void Session::set_parsers(SessionParsers p) +{ + func_parse = p.func_parse; + func_delete = p.func_delete; +} + + +void set_session(io::FD fd, std::unique_ptr<Session> sess) +{ + int f = fd.uncast_dammit(); + assert (0 <= f && f < FD_SETSIZE); + session[f] = std::move(sess); +} +Session *get_session(io::FD fd) +{ + int f = fd.uncast_dammit(); + if (0 <= f && f < FD_SETSIZE) + return session[f].get(); + return nullptr; +} +void reset_session(io::FD fd) +{ + int f = fd.uncast_dammit(); + assert (0 <= f && f < FD_SETSIZE); + session[f] = nullptr; +} +int get_fd_max() { return fd_max; } +IteratorPair<ValueIterator<io::FD, IncrFD>> iter_fds() +{ + return {io::FD::cast_dammit(0), io::FD::cast_dammit(fd_max)}; +} + +/// clean up by discarding handled bytes +inline +void RFIFOFLUSH(Session *s) +{ + really_memmove(&s->rdata[0], &s->rdata[s->rdata_pos], s->rdata_size - s->rdata_pos); + s->rdata_size -= s->rdata_pos; + s->rdata_pos = 0; +} + +/// how much room there is to read more data +inline +size_t RFIFOSPACE(Session *s) +{ + return s->max_rdata - s->rdata_size; +} + + +/// Read from socket to the queue +static +void recv_to_fifo(Session *s) +{ + ssize_t len = s->fd.read(&s->rdata[s->rdata_size], + RFIFOSPACE(s)); + + if (len > 0) + { + s->rdata_size += len; + s->connected = 1; + } + else + { + s->set_eof(); + } +} + +static +void send_from_fifo(Session *s) +{ + ssize_t len = s->fd.write(&s->wdata[0], s->wdata_size); + + if (len > 0) + { + s->wdata_size -= len; + if (s->wdata_size) + { + really_memmove(&s->wdata[0], &s->wdata[len], + s->wdata_size); + } + s->connected = 1; + } + else + { + s->set_eof(); + } +} + +static +void nothing_delete(Session *s) +{ + (void)s; +} + +static +void connect_client(Session *ls) +{ + struct sockaddr_in client_address; + socklen_t len = sizeof(client_address); + + io::FD fd = ls->fd.accept(reinterpret_cast<struct sockaddr *>(&client_address), &len); + if (fd == io::FD()) + { + perror("accept"); + return; + } + if (fd.uncast_dammit() >= SOFT_LIMIT) + { + FPRINTF(stderr, "softlimit reached, disconnecting : %d\n"_fmt, fd.uncast_dammit()); + fd.shutdown(SHUT_RDWR); + fd.close(); + return; + } + if (fd_max <= fd.uncast_dammit()) + { + fd_max = fd.uncast_dammit() + 1; + } + + const int yes = 1; + /// Allow to bind() again after the server restarts. + // Since the socket is still in the TIME_WAIT, there's a possibility + // that formerly lost packets might be delivered and confuse the server. + fd.setsockopt(SOL_SOCKET, SO_REUSEADDR, &yes, sizeof yes); + /// Send packets as soon as possible + /// even if the kernel thinks there is too little for it to be worth it! + /// Testing shows this is indeed a good idea. + fd.setsockopt(IPPROTO_TCP, TCP_NODELAY, &yes, sizeof yes); + + // Linux-ism: Set socket options to optimize for thin streams + // See http://lwn.net/Articles/308919/ and + // Documentation/networking/tcp-thin.txt .. Kernel 3.2+ +#ifdef TCP_THIN_LINEAR_TIMEOUTS + fd.setsockopt(IPPROTO_TCP, TCP_THIN_LINEAR_TIMEOUTS, &yes, sizeof yes); +#endif +#ifdef TCP_THIN_DUPACK + fd.setsockopt(IPPROTO_TCP, TCP_THIN_DUPACK, &yes, sizeof yes); +#endif + + readfds.set(fd); + + fd.fcntl(F_SETFL, O_NONBLOCK); + + set_session(fd, make_unique<Session>( + SessionIO{.func_recv= recv_to_fifo, .func_send= send_from_fifo}, + ls->for_inferior)); + Session *s = get_session(fd); + s->fd = fd; + s->rdata.new_(RFIFO_SIZE); + s->wdata.new_(WFIFO_SIZE); + s->max_rdata = RFIFO_SIZE; + s->max_wdata = WFIFO_SIZE; + s->client_ip = IP4Address(client_address.sin_addr); + s->created = TimeT::now(); + s->connected = 0; +} + +Session *make_listen_port(uint16_t port, SessionParsers inferior) +{ + struct sockaddr_in server_address; + io::FD fd = io::FD::socket(AF_INET, SOCK_STREAM, 0); + if (fd == io::FD()) + { + perror("socket"); + return nullptr; + } + if (fd_max <= fd.uncast_dammit()) + fd_max = fd.uncast_dammit() + 1; + + fd.fcntl(F_SETFL, O_NONBLOCK); + + const int yes = 1; + /// Allow to bind() again after the server restarts. + // Since the socket is still in the TIME_WAIT, there's a possibility + // that formerly lost packets might be delivered and confuse the server. + fd.setsockopt(SOL_SOCKET, SO_REUSEADDR, &yes, sizeof yes); + /// Send packets as soon as possible + /// even if the kernel thinks there is too little for it to be worth it! + // I'm not convinced this is a good idea; although in minimizes the + // latency for an individual write, it increases traffic in general. + fd.setsockopt(IPPROTO_TCP, TCP_NODELAY, &yes, sizeof yes); + + server_address.sin_family = AF_INET; + DIAG_PUSH(); + DIAG_I(old_style_cast); + DIAG_I(useless_cast); + server_address.sin_addr.s_addr = htonl(INADDR_ANY); + server_address.sin_port = htons(port); + DIAG_POP(); + + if (fd.bind(reinterpret_cast<struct sockaddr *>(&server_address), + sizeof(server_address)) == -1) + { + perror("bind"); + exit(1); + } + if (fd.listen(5) == -1) + { /* error */ + perror("listen"); + exit(1); + } + + readfds.set(fd); + + set_session(fd, make_unique<Session>( + SessionIO{.func_recv= connect_client, .func_send= nullptr}, + SessionParsers{.func_parse= nullptr, .func_delete= nothing_delete})); + Session *s = get_session(fd); + s->for_inferior = inferior; + s->fd = fd; + + s->created = TimeT::now(); + s->connected = 1; + + return s; +} + +Session *make_connection(IP4Address ip, uint16_t port, SessionParsers parsers) +{ + struct sockaddr_in server_address; + io::FD fd = io::FD::socket(AF_INET, SOCK_STREAM, 0); + if (fd == io::FD()) + { + perror("socket"); + return nullptr; + } + if (fd_max <= fd.uncast_dammit()) + fd_max = fd.uncast_dammit() + 1; + + const int yes = 1; + /// Allow to bind() again after the server restarts. + // Since the socket is still in the TIME_WAIT, there's a possibility + // that formerly lost packets might be delivered and confuse the server. + fd.setsockopt(SOL_SOCKET, SO_REUSEADDR, &yes, sizeof yes); + /// Send packets as soon as possible + /// even if the kernel thinks there is too little for it to be worth it! + // I'm not convinced this is a good idea; although in minimizes the + // latency for an individual write, it increases traffic in general. + fd.setsockopt(IPPROTO_TCP, TCP_NODELAY, &yes, sizeof yes); + + server_address.sin_family = AF_INET; + server_address.sin_addr = in_addr(ip); + DIAG_PUSH(); + DIAG_I(old_style_cast); + DIAG_I(useless_cast); + server_address.sin_port = htons(port); + DIAG_POP(); + + fd.fcntl(F_SETFL, O_NONBLOCK); + + /// Errors not caught - we must not block + /// Let the main select() loop detect when we know the state + fd.connect(reinterpret_cast<struct sockaddr *>(&server_address), + sizeof(struct sockaddr_in)); + + readfds.set(fd); + + set_session(fd, make_unique<Session>( + SessionIO{.func_recv= recv_to_fifo, .func_send= send_from_fifo}, + parsers)); + Session *s = get_session(fd); + s->fd = fd; + s->rdata.new_(RFIFO_SIZE); + s->wdata.new_(WFIFO_SIZE); + + s->max_rdata = RFIFO_SIZE; + s->max_wdata = WFIFO_SIZE; + s->created = TimeT::now(); + s->connected = 1; + + return s; +} + +void delete_session(Session *s) +{ + if (!s) + return; + // this needs to be before the fd_max-- + s->func_delete(s); + + io::FD fd = s->fd; + // If this was the highest fd, decrease it + // We could add a loop to decrement fd_max further for every null session, + // but this is cheap and good enough for the typical case + if (fd.uncast_dammit() == fd_max - 1) + fd_max--; + readfds.clr(fd); + { + s->rdata.delete_(); + s->wdata.delete_(); + s->session_data.reset(); + reset_session(fd); + } + + // just close() would try to keep sending buffers + fd.shutdown(SHUT_RDWR); + fd.close(); +} + +void realloc_fifo(Session *s, size_t rfifo_size, size_t wfifo_size) +{ + if (s->max_rdata != rfifo_size && s->rdata_size < rfifo_size) + { + s->rdata.resize(rfifo_size); + s->max_rdata = rfifo_size; + } + if (s->max_wdata != wfifo_size && s->wdata_size < wfifo_size) + { + s->wdata.resize(wfifo_size); + s->max_wdata = wfifo_size; + } +} + +void do_sendrecv(interval_t next_ms) +{ + bool any = false; + io::FD_Set rfd = readfds, wfd; + for (io::FD i : iter_fds()) + { + Session *s = get_session(i); + if (s) + { + any = true; + if (s->wdata_size) + wfd.set(i); + } + } + if (!any) + { + if (!has_timers()) + { + PRINTF("Shutting down - nothing to do\n"_fmt); + // TODO hoist this + runflag = false; + } + return; + } + struct timeval timeout; + { + std::chrono::seconds next_s = std::chrono::duration_cast<std::chrono::seconds>(next_ms); + std::chrono::microseconds next_us = next_ms - next_s; + timeout.tv_sec = next_s.count(); + timeout.tv_usec = next_us.count(); + } + if (io::FD_Set::select(fd_max, &rfd, &wfd, nullptr, &timeout) <= 0) + return; + for (io::FD i : iter_fds()) + { + Session *s = get_session(i); + if (!s) + continue; + if (wfd.isset(i) && !s->eof) + { + if (s->func_send) + //send_from_fifo(i); + s->func_send(s); + } + if (rfd.isset(i) && !s->eof) + { + if (s->func_recv) + //recv_to_fifo(i); + //or connect_client(i); + s->func_recv(s); + } + } +} + +void do_parsepacket(void) +{ + for (io::FD i : iter_fds()) + { + Session *s = get_session(i); + if (!s) + continue; + if (!s->connected + && static_cast<time_t>(TimeT::now()) - static_cast<time_t>(s->created) > CONNECT_TIMEOUT) + { + PRINTF("Session #%d timed out\n"_fmt, s); + s->set_eof(); + } + if (s->rdata_size && !s->eof && s->func_parse) + { + s->func_parse(s); + /// some func_parse may call delete_session + // (that's kind of evil) + s = get_session(i); + if (!s) + continue; + } + if (s->eof) + { + delete_session(s); + continue; + } + /// Reclaim buffer space for what was read + RFIFOFLUSH(s); + } +} +} // namespace tmwa diff --git a/src/net/socket.hpp b/src/net/socket.hpp new file mode 100644 index 0000000..576ef85 --- /dev/null +++ b/src/net/socket.hpp @@ -0,0 +1,177 @@ +#pragma once +// socket.hpp - Network event system. +// +// Copyright © ????-2004 Athena Dev Teams +// Copyright © 2004-2011 The Mana World Development Team +// Copyright © 2011-2014 Ben Longbons <b.r.longbons@gmail.com> +// +// This file is part of The Mana World (Athena server) +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +#include "fwd.hpp" + +#include <algorithm> + +#include <sys/select.h> + +#include <memory> + +#include "../compat/iter.hpp" +#include "../compat/rawmem.hpp" +#include "../compat/time_t.hpp" + +#include "../strings/astring.hpp" +#include "../strings/vstring.hpp" +#include "../strings/xstring.hpp" + +#include "../generic/dumb_ptr.hpp" + +#include "../io/fd.hpp" + +#include "ip.hpp" +#include "timer.t.hpp" + + +namespace tmwa +{ +struct SessionData +{ +}; +struct SessionDeleter +{ + // defined per-server + void operator()(SessionData *sd); +}; + +struct SessionIO +{ + void (*func_recv)(Session *); + void (*func_send)(Session *); +}; + +struct SessionParsers +{ + void (*func_parse)(Session *); + void (*func_delete)(Session *); +}; + +struct Session +{ + Session(SessionIO, SessionParsers); + Session(Session&&) = delete; + Session& operator = (Session&&) = delete; + + void set_io(SessionIO); + void set_parsers(SessionParsers); + + /// Checks whether a newly-connected socket actually does anything + TimeT created; + bool connected; + +private: + /// Flag needed since structure must be freed in a server-dependent manner + bool eof; +public: + void set_eof() { eof = true; } + + /// Currently used by clif_setwaitclose + Timer timed_close; + + /// Since this is a single-threaded application, it can't block + /// These are the read/write queues + dumb_ptr<uint8_t[]> rdata, wdata; + size_t max_rdata, max_wdata; + /// How much is actually in the queue + size_t rdata_size, wdata_size; + /// How much has already been read from the queue + /// Note that there is no need for a wdata_pos + size_t rdata_pos; + + IP4Address client_ip; + +private: + /// Send or recieve + /// Only called when select() indicates the socket is ready + /// If, after that, nothing is read, it sets eof + // These could probably be hard-coded with a little work + void (*func_recv)(Session *); + void (*func_send)(Session *); + /// This is the important one + /// Set to different functions depending on whether the connection + /// is a player or a server/ladmin + void (*func_parse)(Session *); + /// Cleanup function since we're not fully RAII yet + void (*func_delete)(Session *); + +public: + // this really ought to be part of session_data, once that gets sane + SessionParsers for_inferior; + + /// Server-specific data type + // (this really should include the deleter, but ...) + std::unique_ptr<SessionData, SessionDeleter> session_data; + + io::FD fd; + + friend void do_sendrecv(interval_t next); + friend void do_parsepacket(void); + friend void delete_session(Session *); +}; + +inline +int convert_for_printf(Session *s) +{ + return s->fd.uncast_dammit(); +} + +// save file descriptors for important stuff +constexpr int SOFT_LIMIT = FD_SETSIZE - 50; + +// socket timeout to establish a full connection in seconds +constexpr int CONNECT_TIMEOUT = 15; + + +void set_session(io::FD fd, std::unique_ptr<Session> sess); +Session *get_session(io::FD fd); +void reset_session(io::FD fd); +int get_fd_max(); + +class IncrFD +{ +public: + static + io::FD inced(io::FD v) + { + return io::FD::cast_dammit(v.uncast_dammit() + 1); + } +}; +IteratorPair<ValueIterator<io::FD, IncrFD>> iter_fds(); + + +/// open a socket, bind, and listen. Return an fd, or -1 if socket() fails, +/// but exit if bind() or listen() fails +Session *make_listen_port(uint16_t port, SessionParsers inferior); +/// Connect to an address, return a connected socket or -1 +// FIXME - this is IPv4 only! +Session *make_connection(IP4Address ip, uint16_t port, SessionParsers); +/// free() the structure and close() the fd +void delete_session(Session *); +/// Make a the internal queues bigger +void realloc_fifo(Session *s, size_t rfifo_size, size_t wfifo_size); +/// Update all sockets that can be read/written from the queues +void do_sendrecv(interval_t next); +/// Call the parser function for every socket that has read data +void do_parsepacket(void); +} // namespace tmwa diff --git a/src/net/timer.cpp b/src/net/timer.cpp new file mode 100644 index 0000000..6a22616 --- /dev/null +++ b/src/net/timer.cpp @@ -0,0 +1,221 @@ +#include "timer.hpp" +// timer.cpp - Future event scheduler. +// +// Copyright © ????-2004 Athena Dev Teams +// Copyright © 2004-2011 The Mana World Development Team +// Copyright © 2011-2014 Ben Longbons <b.r.longbons@gmail.com> +// +// This file is part of The Mana World (Athena server) +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +#include <sys/stat.h> +#include <sys/time.h> + +#include <cassert> + +#include <algorithm> +#include <queue> + +#include "../strings/zstring.hpp" + +#include "../poison.hpp" + + +namespace tmwa +{ +struct TimerData +{ + /// This will be reset on call, to avoid problems. + Timer *owner; + + /// When it will be triggered + tick_t tick; + /// What will be done + timer_func func; + /// Repeat rate - 0 for oneshot + interval_t interval; + + TimerData(Timer *o, tick_t t, timer_func f, interval_t i) + : owner(o) + , tick(t) + , func(std::move(f)) + , interval(i) + {} +}; + +struct TimerCompare +{ + /// implement "less than" + bool operator() (dumb_ptr<TimerData> l, dumb_ptr<TimerData> r) + { + // C++ provides a max-heap, but we want + // the smallest tick to be the head (a min-heap). + return l->tick > r->tick; + } +}; + +static +std::priority_queue<dumb_ptr<TimerData>, std::vector<dumb_ptr<TimerData>>, TimerCompare> timer_heap; + + +tick_t gettick_cache; + +tick_t milli_clock::now(void) noexcept +{ + struct timeval tval; + // BUG: This will cause strange behavior if the system clock is changed! + // it should be reimplemented in terms of clock_gettime(CLOCK_MONOTONIC, ) + gettimeofday(&tval, nullptr); + return gettick_cache = tick_t(std::chrono::seconds(tval.tv_sec) + + std::chrono::duration_cast<std::chrono::milliseconds>( + std::chrono::microseconds(tval.tv_usec))); +} + +static +void do_nothing(TimerData *, tick_t) +{ +} + +void Timer::cancel() +{ + if (!td) + return; + + assert (this == td->owner); + td->owner = nullptr; + td->func = do_nothing; + td->interval = interval_t::zero(); + td = nullptr; +} + +void Timer::detach() +{ + assert (this == td->owner); + td->owner = nullptr; + td = nullptr; +} + +static +void push_timer_heap(dumb_ptr<TimerData> td) +{ + timer_heap.push(td); +} + +static +dumb_ptr<TimerData> top_timer_heap(void) +{ + if (timer_heap.empty()) + return dumb_ptr<TimerData>(); + return timer_heap.top(); +} + +static +void pop_timer_heap(void) +{ + timer_heap.pop(); +} + +Timer::Timer(tick_t tick, timer_func func, interval_t interval) +: td(dumb_ptr<TimerData>::make(this, tick, std::move(func), interval)) +{ + assert (interval >= interval_t::zero()); + + push_timer_heap(td); +} + +Timer::Timer(Timer&& t) +: td(t.td) +{ + t.td = nullptr; + if (td) + { + assert (td->owner == &t); + td->owner = this; + } +} + +Timer& Timer::operator = (Timer&& t) +{ + std::swap(td, t.td); + if (td) + { + assert (td->owner == &t); + td->owner = this; + } + if (t.td) + { + assert (t.td->owner == this); + t.td->owner = &t; + } + return *this; +} + +interval_t do_timer(tick_t tick) +{ + /// Number of milliseconds until it calls this again + // this says to wait 1 sec if all timers get popped + interval_t nextmin = 1_s; + + while (dumb_ptr<TimerData> td = top_timer_heap()) + { + // while the heap is not empty and + if (td->tick > tick) + { + /// Return the time until the next timer needs to goes off + nextmin = td->tick - tick; + break; + } + pop_timer_heap(); + + // Prevent destroying the object we're in. + // Note: this would be surprising in an interval timer, + // but all interval timers do an immediate explicit detach(). + if (td->owner) + td->owner->detach(); + // If we are too far past the requested tick, call with + // the current tick instead to fix reregistration problems + if (td->tick + 1_s < tick) + td->func(td.operator->(), tick); + else + td->func(td.operator->(), td->tick); + + if (td->interval == interval_t::zero()) + { + td.delete_(); + continue; + } + if (td->tick + 1_s < tick) + td->tick = tick + td->interval; + else + td->tick += td->interval; + push_timer_heap(td); + } + + return std::max(nextmin, 10_ms); +} + +tick_t file_modified(ZString name) +{ + struct stat buf; + if (stat(name.c_str(), &buf)) + return tick_t(); + return tick_t(std::chrono::seconds(buf.st_mtime)); +} + +bool has_timers() +{ + return !timer_heap.empty(); +} +} // namespace tmwa diff --git a/src/net/timer.hpp b/src/net/timer.hpp new file mode 100644 index 0000000..338e339 --- /dev/null +++ b/src/net/timer.hpp @@ -0,0 +1,51 @@ +#pragma once +// timer.hpp - Future event scheduler. +// +// Copyright © ????-2004 Athena Dev Teams +// Copyright © 2004-2011 The Mana World Development Team +// Copyright © 2011-2014 Ben Longbons <b.r.longbons@gmail.com> +// +// This file is part of The Mana World (Athena server) +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +#include "timer.t.hpp" + +#include "fwd.hpp" + +#include "../strings/fwd.hpp" + + +namespace tmwa +{ +// updated automatically when using milli_clock::now() +// which is done only by core.cpp +extern tick_t gettick_cache; + +inline +tick_t gettick(void) +{ + return gettick_cache; +} + +/// Do all timers scheduled before tick, and return the number of +/// milliseconds until the next timer happens +interval_t do_timer(tick_t tick); + +/// Stat a file, and return its modification time, truncated to seconds. +tick_t file_modified(ZString name); + +/// Check if there are any events at all scheduled. +bool has_timers(); +} // namespace tmwa diff --git a/src/net/timer.t.hpp b/src/net/timer.t.hpp new file mode 100644 index 0000000..090e62a --- /dev/null +++ b/src/net/timer.t.hpp @@ -0,0 +1,164 @@ +#pragma once +// timer.t.hpp - Future event scheduler. +// +// Copyright © ????-2004 Athena Dev Teams +// Copyright © 2004-2011 The Mana World Development Team +// Copyright © 2011-2014 Ben Longbons <b.r.longbons@gmail.com> +// +// This file is part of The Mana World (Athena server) +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see <http://www.gnu.org/licenses/>. + +#include "fwd.hpp" + +#include <cstdlib> + +#include <chrono> +#include <functional> + +#include "../ints/little.hpp" + +#include "../generic/dumb_ptr.hpp" + + +namespace tmwa +{ +constexpr +std::chrono::nanoseconds operator "" _ns(unsigned long long ns) +{ return std::chrono::nanoseconds(ns); } +constexpr +std::chrono::microseconds operator "" _us(unsigned long long us) +{ return std::chrono::microseconds(us); } +constexpr +std::chrono::milliseconds operator "" _ms(unsigned long long ms) +{ return std::chrono::milliseconds(ms); } +constexpr +std::chrono::seconds operator "" _s(unsigned long long s) +{ return std::chrono::seconds(s); } +constexpr +std::chrono::minutes operator "" _min(unsigned long long min) +{ return std::chrono::minutes(min); } +constexpr +std::chrono::hours operator "" _h(unsigned long long h) +{ return std::chrono::hours(h); } +constexpr +std::chrono::duration<int, std::ratio<60*60*24>> operator "" _d(unsigned long long d) +{ return std::chrono::duration<int, std::ratio<60*60*24>>(d); } + +/// An implementation of the C++ "clock" concept, exposing +/// durations in milliseconds. +class milli_clock +{ +public: + typedef std::chrono::milliseconds duration; + typedef duration::rep rep; + typedef duration::period period; + typedef std::chrono::time_point<milli_clock, duration> time_point; + static const bool is_steady = true; // assumed - not necessarily true + + static time_point now() noexcept; +}; + +/// A point in time. +typedef milli_clock::time_point tick_t; +/// The difference between two points in time. +typedef milli_clock::duration interval_t; +/// (to get additional arguments, use std::bind or a lambda). +typedef std::function<void (TimerData *, tick_t)> timer_func; + +// 49.7 day problem +inline __attribute__((warn_unused_result)) +bool native_to_network(Little32 *net, tick_t nat) +{ + auto tmp = nat.time_since_epoch().count(); + return native_to_network(net, static_cast<uint32_t>(tmp)); +} + +inline __attribute__((warn_unused_result)) +bool network_to_native(tick_t *nat, Little32 net) +{ + (void)nat; + (void)net; + abort(); +} + +inline __attribute__((warn_unused_result)) +bool native_to_network(Little32 *net, interval_t nat) +{ + auto tmp = nat.count(); + return native_to_network(net, static_cast<uint32_t>(tmp)); +} + +inline __attribute__((warn_unused_result)) +bool network_to_native(interval_t *nat, Little32 net) +{ + uint32_t tmp; + bool rv = network_to_native(&tmp, net); + *nat = interval_t(tmp); + return rv; +} + +inline __attribute__((warn_unused_result)) +bool native_to_network(Little16 *net, interval_t nat) +{ + auto tmp = nat.count(); + return native_to_network(net, static_cast<uint16_t>(tmp)); +} + +inline __attribute__((warn_unused_result)) +bool network_to_native(interval_t *nat, Little16 net) +{ + uint16_t tmp; + bool rv = network_to_native(&tmp, net); + *nat = interval_t(tmp); + return rv; +} + + +class Timer +{ + friend struct TimerData; + dumb_ptr<TimerData> td; + + Timer(const Timer&) = delete; + Timer& operator = (const Timer&) = delete; +public: + /// Don't own anything yet. + Timer() = default; + /// Schedule a timer for the given tick. + /// If you do not wish to keep track of it, call disconnect(). + /// Otherwise, you may cancel() or replace (operator =) it later. + /// + /// If the interval argument is given, the timer will reschedule + /// itself again forever. Otherwise, it will disconnect() itself + /// just BEFORE it is called. + Timer(tick_t tick, timer_func func, interval_t interval=interval_t::zero()); + + Timer(Timer&& t); + Timer& operator = (Timer&& t); + ~Timer() { cancel(); } + + /// Cancel the delivery of this timer's function, and make it falsy. + /// Implementation note: this doesn't actually remove it, just sets + /// the functor to do_nothing, and waits for the tick before removing. + void cancel(); + /// Make it falsy without cancelling the timer, + void detach(); + + /// Check if there is a timer connected. + explicit operator bool() { return bool(td); } + /// Check if there is no connected timer. + bool operator !() { return !td; } +}; +} // namespace tmwa |