diff options
author | David Athay <ko2fan@gmail.com> | 2009-07-10 10:05:47 +0100 |
---|---|---|
committer | David Athay <ko2fan@gmail.com> | 2009-07-10 10:05:47 +0100 |
commit | ea21b3bf96d116964398273f1b096f61462b35dd (patch) | |
tree | 02ace9e83bdd3051855f9578831d29350f2d5336 /src | |
parent | e0884d0ac3dae67e2599c687d26823600f8c81b7 (diff) | |
download | manaserv-ea21b3bf96d116964398273f1b096f61462b35dd.tar.gz manaserv-ea21b3bf96d116964398273f1b096f61462b35dd.tar.bz2 manaserv-ea21b3bf96d116964398273f1b096f61462b35dd.tar.xz manaserv-ea21b3bf96d116964398273f1b096f61462b35dd.zip |
Changed database to using prepared statements, to stop SQL injection attacks
Diffstat (limited to 'src')
-rw-r--r-- | src/account-server/dalstorage.cpp | 222 | ||||
-rw-r--r-- | src/account-server/dalstorage.hpp | 13 | ||||
-rw-r--r-- | src/dal/dataprovider.cpp | 3 | ||||
-rw-r--r-- | src/dal/dataprovider.h | 35 | ||||
-rw-r--r-- | src/dal/sqlitedataprovider.cpp | 61 | ||||
-rw-r--r-- | src/dal/sqlitedataprovider.h | 36 |
6 files changed, 298 insertions, 72 deletions
diff --git a/src/account-server/dalstorage.cpp b/src/account-server/dalstorage.cpp index 615e20f2..6d5cf3f4 100644 --- a/src/account-server/dalstorage.cpp +++ b/src/account-server/dalstorage.cpp @@ -117,10 +117,11 @@ void DALStorage::close() mDb->disconnect(); } -Account *DALStorage::getAccountBySQL(const std::string &query) +Account *DALStorage::getAccountBySQL() { try { - const dal::RecordSet &accountInfo = mDb->execSql(query); +// const dal::RecordSet &accountInfo = mDb->execSql(query); + const dal::RecordSet &accountInfo = mDb->processSql(); // if the account is not even in the database then // we have no choice but to return nothing. @@ -204,8 +205,13 @@ Account *DALStorage::getAccountBySQL(const std::string &query) Account *DALStorage::getAccount(const std::string &userName) { std::ostringstream sql; - sql << "select * from " << ACCOUNTS_TBL_NAME << " where username = \"" << userName << "\";"; - return getAccountBySQL(sql.str()); +// sql << "select * from " << ACCOUNTS_TBL_NAME << " where username = \"" << userName << "\";"; + sql << "SELECT * FROM " << ACCOUNTS_TBL_NAME << " WHERE username = ?"; + if (mDb->prepareSql(sql.str())) + { + mDb->bindString(1, userName); + } + return getAccountBySQL(); } /** @@ -214,11 +220,16 @@ Account *DALStorage::getAccount(const std::string &userName) Account *DALStorage::getAccount(int accountID) { std::ostringstream sql; - sql << "select * from " << ACCOUNTS_TBL_NAME << " where id = '" << accountID << "';"; - return getAccountBySQL(sql.str()); +// sql << "select * from " << ACCOUNTS_TBL_NAME << " where id = '" << accountID << "';"; + sql << "SELECT * FROM " << ACCOUNTS_TBL_NAME << " WHERE id = ?"; + if (mDb->prepareSql(sql.str())) + { + mDb->bindInteger(1, accountID); + } + return getAccountBySQL(); } -Character *DALStorage::getCharacterBySQL(const std::string &query, Account *owner) +Character *DALStorage::getCharacterBySQL(Account *owner) { Character *character; @@ -227,7 +238,8 @@ Character *DALStorage::getCharacterBySQL(const std::string &query, Account *owne string_to< unsigned > toUint; try { - const dal::RecordSet &charInfo = mDb->execSql(query); +// const dal::RecordSet &charInfo = mDb->execSql(query); + const dal::RecordSet &charInfo = mDb->processSql(); // if the character is not even in the database then // we have no choice but to return nothing. @@ -366,15 +378,24 @@ Character *DALStorage::getCharacterBySQL(const std::string &query, Account *owne Character *DALStorage::getCharacter(int id, Account *owner) { std::ostringstream sql; - sql << "select * from " << CHARACTERS_TBL_NAME << " where id = '" << id << "';"; - return getCharacterBySQL(sql.str(), owner); +// sql << "select * from " << CHARACTERS_TBL_NAME << " where id = '" << id << "';"; + sql << "SELECT * FROM " << CHARACTERS_TBL_NAME << " WHERE id = ?"; + if (mDb->prepareSql(sql.str())) + { + mDb->bindInteger(1, id); + } + return getCharacterBySQL(owner); } Character *DALStorage::getCharacter(const std::string &name) { std::ostringstream sql; - sql << "select * from " << CHARACTERS_TBL_NAME << " where name = '" << name << "';"; - return getCharacterBySQL(sql.str(), NULL); + sql << "SELECT * FROM " << CHARACTERS_TBL_NAME << " WHERE name = ?"; + if (mDb->prepareSql(sql.str())) + { + mDb->bindString(1, name); + } + return getCharacterBySQL(NULL); } #if 0 @@ -416,9 +437,17 @@ bool DALStorage::doesUserNameExist(const std::string &name) { try { std::ostringstream sql; - sql << "select count(username) from " << ACCOUNTS_TBL_NAME - << " where username = \"" << name << "\";"; - const dal::RecordSet &accountInfo = mDb->execSql(sql.str()); +// sql << "select count(username) from " << ACCOUNTS_TBL_NAME +// << " where username = \"" << name << "\";"; +// const dal::RecordSet &accountInfo = mDb->execSql(sql.str()); + sql << "SELECT COUNT(username) FROM " << ACCOUNTS_TBL_NAME + << " WHERE username = ?"; + + if (mDb->prepareSql(sql.str())) + { + mDb->bindString(1, name); + } + const dal::RecordSet &accountInfo = mDb->processSql(); std::istringstream ssStream(accountInfo(0, 0)); unsigned int iReturn = 1; @@ -440,9 +469,16 @@ bool DALStorage::doesEmailAddressExist(const std::string &email) { try { std::ostringstream sql; - sql << "select count(email) from " << ACCOUNTS_TBL_NAME - << " where upper(email) = upper(\"" << email << "\");"; - const dal::RecordSet &accountInfo = mDb->execSql(sql.str()); +// sql << "select count(email) from " << ACCOUNTS_TBL_NAME +// << " where upper(email) = upper(\"" << email << "\");"; +// const dal::RecordSet &accountInfo = mDb->execSql(sql.str()); + sql << "SELECT COUNT(email) FROM " << ACCOUNTS_TBL_NAME + << " WHERE UPPER(email) = UPPER(?)"; + if (mDb->prepareSql(sql.str())) + { + mDb->bindString(1, email); + } + const dal::RecordSet &accountInfo = mDb->processSql(); std::istringstream ssStream(accountInfo(0, 0)); unsigned int iReturn = 1; @@ -464,9 +500,15 @@ bool DALStorage::doesCharacterNameExist(const std::string& name) { try { std::ostringstream sql; - sql << "select count(name) from " << CHARACTERS_TBL_NAME - << " where name = \"" << name << "\";"; - const dal::RecordSet &accountInfo = mDb->execSql(sql.str()); +// sql << "select count(name) from " << CHARACTERS_TBL_NAME +// << " where name = \"" << name << "\";"; +// const dal::RecordSet &accountInfo = mDb->execSql(sql.str()); + sql << "SELECT COUNT(name) FROM " << CHARACTERS_TBL_NAME << " WHERE name = ?"; + if (mDb->prepareSql(sql.str())) + { + mDb->bindString(1, name); + } + const dal::RecordSet &accountInfo = mDb->processSql(); std::istringstream ssStream(accountInfo(0, 0)); int iReturn = 1; @@ -537,7 +579,7 @@ bool DALStorage::updateCharacter(Character *character, try { std::map<int, int>::const_iterator skill_it; - for (skill_it = character->getSkillBegin(); + for (skill_it = character->getSkillBegin(); skill_it != character->getSkillEnd(); skill_it++) { updateExperience(character->getDatabaseID(), skill_it->first, skill_it->second); @@ -659,17 +701,29 @@ void DALStorage::addAccount(Account *account) try { // insert the account. - std::ostringstream sql1; - sql1 << "insert into " << ACCOUNTS_TBL_NAME + std::ostringstream sql; + sql << "insert into " << ACCOUNTS_TBL_NAME << " (username, password, email, level, banned, registration, lastlogin)" - << " values (\"" - << account->getName() << "\", \"" - << account->getPassword() << "\", \"" - << account->getEmail() << "\", " - << account->getLevel() << ", 0, " +// << " values (\"" +// << account->getName() << "\", \"" +// << account->getPassword() << "\", \"" +// << account->getEmail() << "\", " +// << account->getLevel() << ", 0, " +// << account->getRegistrationDate() << ", " +// << account->getLastLogin() << ");"; +// mDb->execSql(sql1.str()); + << " VALUES (?, ?, ?, " << account->getLevel() << ", 0, " << account->getRegistrationDate() << ", " << account->getLastLogin() << ");"; - mDb->execSql(sql1.str()); + + if (mDb->prepareSql(sql.str())) + { + mDb->bindString(1, account->getName()); + mDb->bindString(2, account->getPassword()); + mDb->bindString(3, account->getEmail()); + } + + mDb->processSql(); account->setID(mDb->getLastId()); mDb->commitTransaction(); @@ -927,15 +981,25 @@ void DALStorage::addGuild(Guild* guild) { std::ostringstream insertSql; insertSql << "insert into " << GUILDS_TBL_NAME - << " (name) " - << " values (\"" - << guild->getName() << "\");"; - mDb->execSql(insertSql.str()); + << " (name) VALUES (?)"; + if (mDb->prepareSql(insertSql.str())) + { + mDb->bindString(1, guild->getName()); + } + //mDb->execSql(insertSql.str()); + mDb->processSql(); std::ostringstream selectSql; - selectSql << "select id from " << GUILDS_TBL_NAME - << " where name = \"" << guild->getName() << "\";"; - const dal::RecordSet& guildInfo = mDb->execSql(selectSql.str()); + selectSql << "SELECT id FROM " << GUILDS_TBL_NAME + << " WHERE name = ?"; + + if (mDb->prepareSql(selectSql.str())) + { + mDb->bindString(1, guild->getName()); + } + //const dal::RecordSet& guildInfo = mDb->execSql(selectSql.str()); + const dal::RecordSet& guildInfo = mDb->processSql(); + string_to<unsigned int> toUint; unsigned id = toUint(guildInfo(0, 0)); guild->setId(id); @@ -1095,9 +1159,16 @@ std::string DALStorage::getQuestVar(int id, const std::string &name) { std::ostringstream query; query << "select value from " << QUESTS_TBL_NAME - << " where owner_id = '" << id << "' and name = '" - << name << "';"; - const dal::RecordSet &info = mDb->execSql(query.str()); +// << " where owner_id = '" << id << "' and name = '" +// << name << "';"; +// const dal::RecordSet &info = mDb->execSql(query.str()); + << " WHERE owner_id = ? AND name = ?"; + if (mDb->prepareSql(query.str())) + { + mDb->bindInteger(1, id); + mDb->bindString(2, name); + } + const dal::RecordSet &info = mDb->processSql(); if (!info.isEmpty()) return info(0, 0); } @@ -1392,9 +1463,15 @@ void DALStorage::storeLetter(Letter *letter) << letter->getReceiver()->getDatabaseID() << ", " << letter->getExpiry() << ", " << time(NULL) << ", " - << "'" << letter->getContents() << "' )"; +// << "'" << letter->getContents() << "' )"; + << "?)"; + if (mDb->prepareSql(sql.str())) + { + mDb->bindString(1, letter->getContents()); + } - mDb->execSql(sql.str()); + mDb->processSql(); +// mDb->execSql(sql.str()); letter->setId(mDb->getLastId()); // TODO: store attachments in the database @@ -1410,16 +1487,22 @@ void DALStorage::storeLetter(Letter *letter) << " letter_type = '" << letter->getType() << "', " << " expiration_date = '" << letter->getExpiry() << "', " << " sending_date = '" << time(NULL) << "', " - << " letter_text = '" << letter->getContents() << "' " +// << " letter_text = '" << letter->getContents() << "' " + << " letter_text = ? " << " WHERE letter_id = '" << letter->getId() << "'"; - mDb->execSql(sql.str()); + if (mDb->prepareSql(sql.str())) + { + mDb->bindString(1, letter->getContents()); + } + mDb->processSql(); + //mDb->execSql(sql.str()); if (mDb->getModifiedRows() == 0) { // this should never happen... - LOG_ERROR("(DALStorage::storePost) trying to update nonexsistant letter"); - throw "(DALStorage::storePost) trying to update nonexsistant letter"; + LOG_ERROR("(DALStorage::storePost) trying to update nonexistant letter"); + throw "(DALStorage::storePost) trying to update nonexistant letter"; } // TODO: update attachments in the database @@ -1561,25 +1644,42 @@ void DALStorage::SyncDatabase(void) { std::ostringstream sql; sql << "UPDATE " << ITEMS_TBL_NAME - << " SET name = '" << mDb->escapeSQL(name) << "', " - << " description = '" << mDb->escapeSQL(desc) << "', " +// << " SET name = '" << mDb->escapeSQL(name) << "', " + << " SET name = ?, " +// << " description = '" << mDb->escapeSQL(desc) << "', " + << " description = ?, " << " image = '" << image << "', " << " weight = " << weight << ", " << " itemtype = '" << type << "', " - << " effect = '" << mDb->escapeSQL(eff) << "', " +// << " effect = '" << mDb->escapeSQL(eff) << "', " + << " effect = ?, " << " dyestring = '" << dye << "' " << " WHERE id = " << id; - mDb->execSql(sql.str()); +// mDb->execSql(sql.str()); + if (mDb->prepareSql(sql.str())) + { + mDb->bindString(1, name); + mDb->bindString(2, desc); + mDb->bindString(3, eff); + } + mDb->processSql(); if (mDb->getModifiedRows() == 0) { sql.clear(); sql.str(""); sql << "INSERT INTO " << ITEMS_TBL_NAME - << " VALUES ( " << id << ", '" << name << "', '" - << desc << "', '" << image << "', " << weight << ", '" - << type << "', '" << eff << "', '" << dye << "' )"; - mDb->execSql(sql.str()); + << " VALUES ( " << id << ", ?, ?, '" + << image << "', " << weight << ", '" + << type << "', ?, '" << dye << "' )"; + //mDb->execSql(sql.str()); + if (mDb->prepareSql(sql.str())) + { + mDb->bindString(1, name); + mDb->bindString(2, desc); + mDb->bindString(3, eff); + } + mDb->processSql(); } itmCount++; } @@ -1637,9 +1737,17 @@ void DALStorage::addTransaction(const Transaction &trans) { std::stringstream sql; sql << "INSERT INTO " << TRANSACTION_TBL_NAME - << " VALUES (NULL, " << trans.mCharacterId << ", " << trans.mAction - << ", '" << trans.mMessage << "', " << time(NULL) << ")"; - mDb->execSql(sql.str()); + << " VALUES (NULL, " << trans.mCharacterId << ", " + << trans.mAction << ", " + << "?, " + << time(NULL) << ")"; +// << ", '" << trans.mMessage << "', " << time(NULL) << ")"; +// mDb->execSql(sql.str()); + if (mDb->prepareSql(sql.str())) + { + mDb->bindString(1, trans.mMessage); + } + mDb->processSql(); } catch (const dal::DbSqlQueryExecFailure &e) { diff --git a/src/account-server/dalstorage.hpp b/src/account-server/dalstorage.hpp index 8778b64c..77222377 100644 --- a/src/account-server/dalstorage.hpp +++ b/src/account-server/dalstorage.hpp @@ -395,24 +395,21 @@ class DALStorage operator=(const DALStorage& rhs); /** - * Gets an account by using a SQL query string. + * Gets an account from a prepared SQL statement * - * @param query the query for the account - * - * @return the account found by the query + * @return the account found */ - Account *getAccountBySQL(const std::string &query); + Account *getAccountBySQL(); /** - * Gets a character by character name. + * Gets a character from a prepared SQL statement * - * @param query the query for the character. * @param owner the account the character is in. * * @return the character found by the query. */ - Character *getCharacterBySQL(const std::string &query, Account *owner); + Character *getCharacterBySQL(Account *owner); /** * Synchronizes the base data in the connected SQL database with the xml diff --git a/src/dal/dataprovider.cpp b/src/dal/dataprovider.cpp index e903d0ca..8fd8fe0f 100644 --- a/src/dal/dataprovider.cpp +++ b/src/dal/dataprovider.cpp @@ -71,7 +71,7 @@ DataProvider::getDbName(void) return mDbName; } - +/* std::string& DataProvider::escapeSQL(std::string &sql) { size_t pos = 0; @@ -86,5 +86,6 @@ std::string& DataProvider::escapeSQL(std::string &sql) return sql; } +*/ } // namespace dal diff --git a/src/dal/dataprovider.h b/src/dal/dataprovider.h index f9509492..65de8ee5 100644 --- a/src/dal/dataprovider.h +++ b/src/dal/dataprovider.h @@ -178,12 +178,37 @@ class DataProvider getLastId(void) const = 0; /** - * Takes a SQL snippet and escapes special caharacters like ' to prevent - * SQL injection attacks. - * - * @param sql SQL Snippet to escape. + * Prepare SQL statement + */ + virtual bool prepareSql(const std::string &sql) = 0; + + /** + * Process SQL statement + * SQL statement needs to be prepared and parameters binded before + * calling this function + */ + virtual const RecordSet& processSql() = 0; + + /** + * Bind String + * @param place - which parameter to bind to + * @param value - the string to bind + */ + virtual void bindString(int place, const std::string &value) = 0; + + /** + * Bind Integer + * @param place - which parameter to bind to + * @param value - the integer to bind + */ + virtual void bindInteger(int place, int value) = 0; + + /** + * Bind Float + * @param place - which parameter to bind to + * @param value - the float to bind */ - std::string& escapeSQL(std::string &sql); + virtual void bindFloat(int place, float value) = 0; protected: std::string mDbName; /**< the database name */ diff --git a/src/dal/sqlitedataprovider.cpp b/src/dal/sqlitedataprovider.cpp index 1eaac780..accf979b 100644 --- a/src/dal/sqlitedataprovider.cpp +++ b/src/dal/sqlitedataprovider.cpp @@ -370,4 +370,65 @@ SqLiteDataProvider::getLastId(void) const return (unsigned int)lastId; } +bool SqLiteDataProvider::prepareSql(const std::string &sql) +{ + if (!mIsConnected) + return false; + + LOG_DEBUG("Preparing SQL statement: "<<sql); + + mRecordSet.clear(); + + if (sqlite3_prepare_v2(mDb, sql.c_str(), sql.size(), &mStmt, NULL) != SQLITE_OK) + { + return false; + } + + return true; +} + +const RecordSet& SqLiteDataProvider::processSql() +{ + if (!mIsConnected) { + throw std::runtime_error("not connected to database"); + } + + int totalCols = sqlite3_column_count(mStmt); + Row fieldNames; + + while (sqlite3_step(mStmt) == SQLITE_ROW) + { + Row r; + for (int col = 0; col < totalCols; ++col) + { + fieldNames.push_back(sqlite3_column_name(mStmt, col)); + r.push_back((char*)sqlite3_column_text(mStmt, col)); + } + // ensure we set column headers before adding a row + mRecordSet.setColumnHeaders(fieldNames); + mRecordSet.add(r); + } + + + + sqlite3_finalize(mStmt); + + return mRecordSet; +} + +void SqLiteDataProvider::bindString(int place, const std::string &value) +{ + sqlite3_bind_text(mStmt, place, value.c_str(), value.size(), SQLITE_STATIC); +} + +void SqLiteDataProvider::bindInteger(int place, int value) +{ + sqlite3_bind_int(mStmt, place, value); +} + +void SqLiteDataProvider::bindFloat(int place, float value) +{ + sqlite3_bind_double(mStmt, place, value); +} + } // namespace dal diff --git a/src/dal/sqlitedataprovider.h b/src/dal/sqlitedataprovider.h index 8950f312..3f1951d2 100644 --- a/src/dal/sqlitedataprovider.h +++ b/src/dal/sqlitedataprovider.h @@ -146,6 +146,39 @@ class SqLiteDataProvider: public DataProvider const unsigned int getLastId(void) const; + /** + * Prepare SQL statement + */ + bool prepareSql(const std::string &sql); + + /** + * Process SQL statement + * SQL statement needs to be prepared and parameters binded before + * calling this function + */ + const RecordSet& processSql(); + + /** + * Bind String + * @param place - which parameter to bind to + * @param value - the string to bind + */ + void bindString(int place, const std::string &value); + + /** + * Bind Integer + * @param place - which parameter to bind to + * @param value - the integer to bind + */ + void bindInteger(int place, int value); + + /** + * Bind Float + * @param place - which parameter to bind to + * @param value - the float to bind + */ + void bindFloat(int place, float value); + private: /** defines the name of the database config parameter */ @@ -162,7 +195,8 @@ class SqLiteDataProvider: public DataProvider const bool inTransaction(void) const; - sqlite3* mDb; /**< the handle to the database connection */ + sqlite3 *mDb; /**< the handle to the database connection */ + sqlite3_stmt *mStmt; /**< the prepared statement to process */ }; |