From 77f7206d91a12abd4effd5c20188653e83faa54b Mon Sep 17 00:00:00 2001 From: Thorbjørn Lindeijer Date: Fri, 19 Aug 2022 15:04:03 +0200 Subject: Fixed possible leak in AccountHandler::handleUnregisterMessage Fixed by changing account instances to be managed by std::unique_ptr, so we don't forget to delete them somewhere, like in that function as well as during shutdown in AccountHandler. --- src/account-server/accountclient.h | 6 +-- src/account-server/accounthandler.cpp | 67 +++++++++++++++++---------------- src/account-server/character.cpp | 1 + src/account-server/storage.cpp | 70 +++++++++++++++++------------------ src/account-server/storage.h | 21 ++++++----- 5 files changed, 85 insertions(+), 80 deletions(-) diff --git a/src/account-server/accountclient.h b/src/account-server/accountclient.h index afb2ef3e..3973d60b 100644 --- a/src/account-server/accountclient.h +++ b/src/account-server/accountclient.h @@ -45,7 +45,7 @@ class AccountClient : public NetComputer public: AccountClient(ENetPeer *peer); - void setAccount(Account *acc); + void setAccount(std::unique_ptr acc); void unsetAccount(); Account *getAccount() const; @@ -59,9 +59,9 @@ class AccountClient : public NetComputer /** * Set the account associated with the connection. */ -inline void AccountClient::setAccount(Account *acc) +inline void AccountClient::setAccount(std::unique_ptr acc) { - mAccount.reset(acc); + mAccount = std::move(acc); } /** diff --git a/src/account-server/accounthandler.cpp b/src/account-server/accounthandler.cpp index 8f871455..856f07d0 100644 --- a/src/account-server/accounthandler.cpp +++ b/src/account-server/accounthandler.cpp @@ -102,7 +102,7 @@ private: /** List of all accounts which requested a random seed, but are not logged * yet. This list will be regularly remove (after timeout) old accounts */ - std::list mPendingAccounts; + std::list> mPendingAccounts; /** List of attributes that the client can send at account creation. */ std::vector mModifiableAttributes; @@ -308,13 +308,14 @@ static void sendFullCharacterData(AccountClient *client, static std::string getRandomString(int length) { - char s[length]; + std::string s; + s.resize(length); // No need to care about zeros. They can be handled. // But care for endianness for (int i = 0; i < length; ++i) s[i] = (char)rand(); - return std::string(s, length); + return s; } void AccountHandler::handleLoginRandTriggerMessage(AccountClient &client, MessageIn &msg) @@ -322,10 +323,10 @@ void AccountHandler::handleLoginRandTriggerMessage(AccountClient &client, Messag std::string salt = getRandomString(4); std::string username = msg.readString(); - if (Account *acc = storage->getAccount(username)) + if (auto acc = storage->getAccount(username)) { acc->setRandomSalt(salt); - mPendingAccounts.push_back(acc); + mPendingAccounts.push_back(std::move(acc)); } MessageOut reply(APMSG_LOGIN_RNDTRGR_RESPONSE); reply.writeString(salt); @@ -389,17 +390,22 @@ void AccountHandler::handleLoginMessage(AccountClient &client, MessageIn &msg) } // Check if the account exists - Account *acc = nullptr; - for (Account *account : mPendingAccounts) - if (account->getName() == username) - acc = account; - mPendingAccounts.remove(acc); + auto accIt = std::find_if(mPendingAccounts.begin(), + mPendingAccounts.end(), + [&] (const std::unique_ptr &acc) { + return acc->getName() == username; + }); + + std::unique_ptr acc; + if (accIt != mPendingAccounts.end()) { + acc = std::move(*accIt); + mPendingAccounts.erase(accIt); + } if (!acc || sha256(acc->getPassword() + acc->getRandomSalt()) != password) { reply.writeInt8(ERRMSG_INVALID_ARGUMENT); client.send(reply); - delete acc; return; } @@ -407,7 +413,6 @@ void AccountHandler::handleLoginMessage(AccountClient &client, MessageIn &msg) { reply.writeInt8(LOGIN_BANNED); client.send(reply); - delete acc; return; } @@ -417,16 +422,12 @@ void AccountHandler::handleLoginMessage(AccountClient &client, MessageIn &msg) time_t login; time(&login); acc->setLastLogin(login); - storage->updateLastLogin(acc); - - // Associate account with connection. - client.setAccount(acc); - client.status = CLIENT_CONNECTED; + storage->updateLastLogin(*acc); reply.writeInt8(ERRMSG_OK); addServerInfo(&reply); - Characters &chars = acc->getCharacters(); + const Characters &chars = acc->getCharacters(); if (client.version < 10) { client.send(reply); @@ -436,6 +437,10 @@ void AccountHandler::handleLoginMessage(AccountClient &client, MessageIn &msg) sendCharacterData(reply, charIt.second); client.send(reply); } + + // Associate account with connection. + client.setAccount(std::move(acc)); + client.status = CLIENT_CONNECTED; } void AccountHandler::handleLogoutMessage(AccountClient &client) @@ -529,7 +534,7 @@ void AccountHandler::handleRegisterMessage(AccountClient &client, } else { - Account *acc = new Account; + std::unique_ptr acc { new Account }; acc->setName(username); acc->setPassword(sha256(password)); // We hash email server-side for additional privacy @@ -544,12 +549,12 @@ void AccountHandler::handleRegisterMessage(AccountClient &client, acc->setRegistrationDate(regdate); acc->setLastLogin(regdate); - storage->addAccount(acc); + storage->addAccount(*acc); reply.writeInt8(ERRMSG_OK); addServerInfo(&reply); // Associate account with connection - client.setAccount(acc); + client.setAccount(std::move(acc)); client.status = CLIENT_CONNECTED; } @@ -581,20 +586,19 @@ void AccountHandler::handleUnregisterMessage(AccountClient &client, } // See whether the account exists - Account *acc = storage->getAccount(username); + auto acc = storage->getAccount(username); if (!acc || acc->getPassword() != sha256(password)) { reply.writeInt8(ERRMSG_INVALID_ARGUMENT); client.send(reply); - delete acc; return; } // Delete account and associated characters LOG_INFO("Unregistered \"" << username << "\", AccountID: " << acc->getID()); - storage->delAccount(acc); + storage->delAccount(*acc); reply.writeInt8(ERRMSG_OK); client.send(reply); @@ -654,7 +658,7 @@ void AccountHandler::handleEmailChangeMessage(AccountClient &client, { acc->setEmail(emailHash); // Keep the database up to date otherwise we will go out of sync - storage->flush(acc); + storage->flush(*acc); reply.writeInt8(ERRMSG_OK); } client.send(reply); @@ -685,7 +689,7 @@ void AccountHandler::handlePasswordChangeMessage(AccountClient &client, { acc->setPassword(newPassword); // Keep the database up to date otherwise we will go out of sync - storage->flush(acc); + storage->flush(*acc); reply.writeInt8(ERRMSG_OK); } @@ -822,7 +826,7 @@ void AccountHandler::handleCharacterCreateMessage(AccountClient &client, LOG_INFO("Character " << name << " was created for " << acc->getName() << "'s account."); - storage->flush(acc); // flush changes + storage->flush(*acc); // flush changes // log transaction Transaction trans; @@ -938,7 +942,7 @@ void AccountHandler::handleCharacterDeleteMessage(AccountClient &client, return; } - std::string characterName = chars[slot]->getName(); + const std::string &characterName = chars[slot]->getName(); LOG_INFO("Character deleted:" << characterName); // Log transaction @@ -950,7 +954,7 @@ void AccountHandler::handleCharacterDeleteMessage(AccountClient &client, storage->addTransaction(trans); acc->delCharacter(slot); - storage->flush(acc); + storage->flush(*acc); reply.writeInt8(ERRMSG_OK); client.send(reply); @@ -976,15 +980,14 @@ void AccountHandler::tokenMatched(AccountClient *client, int accountID) MessageOut reply(APMSG_RECONNECT_RESPONSE); // Associate account with connection. - Account *acc = storage->getAccount(accountID); - client->setAccount(acc); + client->setAccount(storage->getAccount(accountID)); client->status = CLIENT_CONNECTED; reply.writeInt8(ERRMSG_OK); client->send(reply); // Return information about available characters - Characters &chars = acc->getCharacters(); + const Characters &chars = client->getAccount()->getCharacters(); // Send characters list sendFullCharacterData(client, chars); diff --git a/src/account-server/character.cpp b/src/account-server/character.cpp index 35a8b079..80b28b55 100644 --- a/src/account-server/character.cpp +++ b/src/account-server/character.cpp @@ -119,6 +119,7 @@ void CharacterData::serialize(MessageOut &msg) msg.writeInt8(0); // not equipped } } + void CharacterData::deserialize(MessageIn &msg) { // general character properties diff --git a/src/account-server/storage.cpp b/src/account-server/storage.cpp index fd9160d3..ea2b0eff 100644 --- a/src/account-server/storage.cpp +++ b/src/account-server/storage.cpp @@ -162,7 +162,7 @@ void Storage::close() mDb->disconnect(); } -Account *Storage::getAccountBySQL() +std::unique_ptr Storage::getAccountBySQL() { try { @@ -178,7 +178,7 @@ Account *Storage::getAccountBySQL() // Create an Account instance // and initialize it with information about the user. - Account *account = new Account(id); + std::unique_ptr account { new Account(id) }; account->setName(accountInfo(0, 1)); account->setPassword(accountInfo(0, 2)); account->setEmail(accountInfo(0, 3)); @@ -223,7 +223,7 @@ Account *Storage::getAccountBySQL() for (int k = 0; k < size; ++k) { if (CharacterData *ptr = - getCharacter(characterIDs[k], account)) + getCharacter(characterIDs[k], account.get())) { characters[ptr->getCharacterSlot()] = ptr; } @@ -315,7 +315,7 @@ void Storage::fixCharactersSlot(int accountId) } } -Account *Storage::getAccount(const std::string &userName) +std::unique_ptr Storage::getAccount(const std::string &userName) { std::ostringstream sql; sql << "SELECT * FROM " << ACCOUNTS_TBL_NAME << " WHERE username = ?"; @@ -327,7 +327,7 @@ Account *Storage::getAccount(const std::string &userName) return 0; } -Account *Storage::getAccount(int accountID) +std::unique_ptr Storage::getAccount(int accountID) { std::ostringstream sql; sql << "SELECT * FROM " << ACCOUNTS_TBL_NAME << " WHERE id = ?"; @@ -892,9 +892,9 @@ bool Storage::updateCharacter(CharacterData *character) return true; } -void Storage::addAccount(Account *account) +void Storage::addAccount(Account &account) { - assert(account->getCharacters().size() == 0); + assert(account.getCharacters().size() == 0); using namespace dal; @@ -906,18 +906,18 @@ void Storage::addAccount(Account *account) << " (username, password, email, level, " << "banned, registration, lastlogin)" << " VALUES (?, ?, ?, " - << account->getLevel() << ", 0, " - << account->getRegistrationDate() << ", " - << account->getLastLogin() << ");"; + << account.getLevel() << ", 0, " + << account.getRegistrationDate() << ", " + << account.getLastLogin() << ");"; if (mDb->prepareSql(sql.str())) { - mDb->bindValue(1, account->getName()); - mDb->bindValue(2, account->getPassword()); - mDb->bindValue(3, account->getEmail()); + mDb->bindValue(1, account.getName()); + mDb->bindValue(2, account.getPassword()); + mDb->bindValue(3, account.getEmail()); mDb->processSql(); - account->setID(mDb->getLastId()); + account.setID(mDb->getLastId()); } else { @@ -931,9 +931,9 @@ void Storage::addAccount(Account *account) } } -void Storage::flush(Account *account) +void Storage::flush(const Account &account) { - assert(account->getID() >= 0); + assert(account.getID() >= 0); using namespace dal; @@ -950,12 +950,12 @@ void Storage::flush(Account *account) if (mDb->prepareSql(sqlUpdateAccountTable.str())) { - mDb->bindValue(1, account->getName()); - mDb->bindValue(2, account->getPassword()); - mDb->bindValue(3, account->getEmail()); - mDb->bindValue(4, account->getLevel()); - mDb->bindValue(5, account->getLastLogin()); - mDb->bindValue(6, account->getID()); + mDb->bindValue(1, account.getName()); + mDb->bindValue(2, account.getPassword()); + mDb->bindValue(3, account.getEmail()); + mDb->bindValue(4, account.getLevel()); + mDb->bindValue(5, account.getLastLogin()); + mDb->bindValue(6, account.getID()); mDb->processSql(); } @@ -966,7 +966,7 @@ void Storage::flush(Account *account) } // Get the list of characters that belong to this account. - Characters &characters = account->getCharacters(); + const Characters &characters = account.getCharacters(); // Insert or update the characters. for (Characters::const_iterator it = characters.begin(), @@ -988,12 +988,12 @@ void Storage::flush(Account *account) << " (user_id, name, gender, hair_style, hair_color," << " char_pts, correct_pts," << " x, y, map_id, slot) values (" - << account->getID() << ", ?, " + << account.getID() << ", ?, " << character->getGender() << ", " - << (int)character->getHairStyle() << ", " - << (int)character->getHairColor() << ", " - << (int)character->getAttributePoints() << ", " - << (int)character->getCorrectionPoints() << ", " + << character->getHairStyle() << ", " + << character->getHairColor() << ", " + << character->getAttributePoints() << ", " + << character->getCorrectionPoints() << ", " << character->getPosition().x << ", " << character->getPosition().y << ", " << character->getMapId() << ", " @@ -1029,7 +1029,7 @@ void Storage::flush(Account *account) std::ostringstream sqlSelectNameIdCharactersTable; sqlSelectNameIdCharactersTable << "select name, id from " << CHARACTERS_TBL_NAME - << " where user_id = '" << account->getID() << "';"; + << " where user_id = '" << account.getID() << "';"; const RecordSet& charInMemInfo = mDb->execSql(sqlSelectNameIdCharactersTable.str()); @@ -1069,7 +1069,7 @@ void Storage::flush(Account *account) } } -void Storage::delAccount(Account *account) +void Storage::delAccount(Account &account) { // Sync the account info into the database. flush(account); @@ -1079,11 +1079,11 @@ void Storage::delAccount(Account *account) // Delete the account. std::ostringstream sql; sql << "delete from " << ACCOUNTS_TBL_NAME - << " where id = '" << account->getID() << "';"; + << " where id = '" << account.getID() << "';"; mDb->execSql(sql.str()); // Remove the account's characters. - account->setCharacters(Characters()); + account.setCharacters(Characters()); } catch (const std::exception &e) { @@ -1091,14 +1091,14 @@ void Storage::delAccount(Account *account) } } -void Storage::updateLastLogin(const Account *account) +void Storage::updateLastLogin(const Account &account) { try { std::ostringstream sql; sql << "UPDATE " << ACCOUNTS_TBL_NAME - << " SET lastlogin = '" << account->getLastLogin() << "'" - << " WHERE id = '" << account->getID() << "';"; + << " SET lastlogin = '" << account.getLastLogin() << "'" + << " WHERE id = '" << account.getID() << "';"; mDb->execSql(sql.str()); } catch (const dal::DbSqlQueryExecFailure &e) diff --git a/src/account-server/storage.h b/src/account-server/storage.h index d635fb9b..ca83e6f1 100644 --- a/src/account-server/storage.h +++ b/src/account-server/storage.h @@ -23,6 +23,7 @@ #include #include +#include #include #include "dal/dataprovider.h" @@ -47,6 +48,9 @@ class Storage Storage(); ~Storage(); + Storage(const Storage &rhs) = delete; + Storage &operator=(const Storage &rhs) = delete; + /** * Connect to the database and initialize it if necessary. */ @@ -64,7 +68,7 @@ class Storage * * @return the account associated to the user name. */ - Account *getAccount(const std::string &userName); + std::unique_ptr getAccount(const std::string &userName); /** * Get an account by Id. @@ -73,7 +77,7 @@ class Storage * * @return the account associated with the Id. */ - Account *getAccount(int accountId); + std::unique_ptr getAccount(int accountId); /** * Gets a character by database Id. @@ -108,21 +112,21 @@ class Storage * * @param account the new account. */ - void addAccount(Account *account); + void addAccount(Account &account); /** * Delete an account and its associated data from the database. * * @param account the account to delete. */ - void delAccount(Account *account); + void delAccount(Account &account); /** * Update the date and time of the last login. * * @param account the account that recently logged in. */ - void updateLastLogin(const Account *account); + void updateLastLogin(const Account &account); /** * Write a modification message about Character points to the database. @@ -314,7 +318,7 @@ class Storage * * @param Account object to update. */ - void flush(Account *); + void flush(const Account &); /** * Gets the value of a quest variable. @@ -440,15 +444,12 @@ class Storage { return mDb; } private: - Storage(const Storage &rhs) = delete; - Storage &operator=(const Storage &rhs) = delete; - /** * Gets an account from a prepared SQL statement * * @return the account found */ - Account *getAccountBySQL(); + std::unique_ptr getAccountBySQL(); /** * Gets a character from a prepared SQL statement -- cgit v1.2.3-70-g09d2