MemCacheClient

MemCacheClient.cpp

Go to the documentation of this file.
00001 
00005 //#include "stdafx.h"
00006 
00007 #include <algorithm>
00008 
00009 #ifdef _WIN32
00010 # include <winsock2.h>
00011 # define strtoull _strtoui64
00012 #else
00013 # include <stdint.h>
00014 #endif
00015 
00016 // If OpenSSL is available, better to use it
00017 //#include <openssl/sha.h>
00018 #include "sha1.h"
00019 #ifndef HEADER_SHA_H //openssl
00020 #define SHA_DIGEST_LENGTH   SHA1_DIGEST_LENGTH
00021 void SHA1(const unsigned char *d, size_t n, unsigned char *md) {
00022     SHA1((sha1_byte*)md, (const sha1_byte*)d, (unsigned int)n);
00023 }
00024 #endif
00025 
00026 // local
00027 #include "Socket.h"
00028 #include "MemCacheClient.h"
00029 
00030 // lib
00031 #ifdef CROSSBASE_API
00032 # include <xplatform/timer.h>
00033 # include <Trace/cltrace.h>
00034 START_CL_NAMESPACE
00035 #else
00036 # include "Matilda.h"
00037 #endif
00038 
00040 
00042 class MemCacheClient::Server : public Socket
00043 {
00044 public:
00046     Server(const ClTrace & aTrace) 
00047         : Socket(aTrace)
00048         , mIp(INADDR_NONE)
00049         , mPort(0)
00050         , mLastConnect(0) 
00051     { 
00052         mAddress[0] = 0; 
00053     }
00054     
00058     Server(const Server & rhs) : Socket(rhs.mTrace) { operator=(rhs); }
00059     
00061     ~Server() { }
00062 
00066     Server & operator=(const Server & rhs);
00067     
00072     bool operator==(const Server & rhs) const;
00073     
00078     inline bool operator!=(const Server & rhs) const { return !operator==(rhs); }
00079     
00084     bool Set(const char * aServer); 
00085     
00086     enum ConnectResult { CONNECT_SUCCESS, CONNECT_FAILED, CONNECT_WAITING };
00087 
00092     ConnectResult Connect(size_t aTimeout, size_t aRetryPeriod);
00093     
00101     inline const char * GetAddress() const { return mAddress; }
00102 
00104     inline int GetPort() const { return mPort; }
00105 
00106 private:
00108     const static size_t ADDRLEN = sizeof("aaa.bbb.ccc.ddd:PPPPP");
00109 
00110     char            mAddress[ADDRLEN];  
00111     unsigned long   mIp;                
00112     int             mPort;              
00113     unsigned long   mLastConnect;       
00114 };
00115 
00116 MemCacheClient::Server & 
00117 MemCacheClient::Server::operator=(
00118     const Server & rhs
00119     ) 
00120 {
00121     if (this != &rhs) {
00122         mTrace = rhs.mTrace;
00123         strcpy(mAddress, rhs.mAddress);
00124         mIp   = rhs.mIp;
00125         mPort = rhs.mPort;
00126         mLastConnect = 0;
00127     }
00128     return *this;
00129 }
00130 
00131 bool 
00132 MemCacheClient::Server::operator==(
00133     const Server & rhs
00134     ) const
00135 {
00136     return mIp == rhs.mIp && mPort == rhs.mPort;
00137 }
00138 
00139 bool 
00140 MemCacheClient::Server::Set(
00141     const char * aServer
00142     ) 
00143 {
00144     if (!aServer || !*aServer) return false;
00145 
00146     char server[200];
00147     size_t nLen = strlen(aServer);
00148     if (nLen >= sizeof(server)) return false; 
00149     strcpy(server, aServer);
00150 
00151     mPort = 11211;
00152     char * pszPort = strchr(server, ':');
00153     if (pszPort) {
00154         mPort = atoi(pszPort + 1);
00155         *pszPort = 0;
00156     }
00157 
00158     mIp = inet_addr(server);
00159     if (mIp == INADDR_NONE) return false;
00160 
00161     struct in_addr addr;
00162     addr.s_addr = mIp;
00163     snprintf(mAddress, ADDRLEN, "%s:%d", inet_ntoa(addr), mPort);
00164 
00165     return true;
00166 }
00167 
00168 MemCacheClient::Server::ConnectResult
00169 MemCacheClient::Server::Connect(
00170     size_t aTimeout,
00171     size_t aRetryPeriod
00172     ) 
00173 {
00174     // already connected? do nothing
00175     if (Socket::IsConnected()) {
00176         return CONNECT_SUCCESS;
00177     }
00178 
00179     struct in_addr addr;
00180     addr.s_addr = mIp;
00181     const char * pszAddress = inet_ntoa(addr);
00182 
00183     // only try to re-connect to a broken server occasionally if it is optional. 
00184     // a required server will be attempted every time.
00185     unsigned long nNow = xplatform::GetCurrentTickCount();
00186     if (mLastConnect && (nNow - mLastConnect) < aRetryPeriod) {
00187         mTrace.Trace(CLDEBUG, "Connection attempt to %s:%d ignored (last failed attempt %lu seconds ago)",
00188             pszAddress, mPort, (nNow - mLastConnect) / 1000);
00189         return CONNECT_WAITING;
00190     }
00191     mLastConnect = nNow;
00192 
00193     try {
00194         // use a decent size socket buffer 
00195         mBufferSize = 32 * 1024;
00196         mConnectTimeout = (int) aTimeout;
00197         mSendTimeout = (int) aTimeout;
00198         mRecvTimeout = (int) aTimeout;
00199         Socket::Connect(pszAddress, mPort);
00200     }
00201     catch (const Socket::Exception &) { 
00202         // message already logged
00203         return CONNECT_FAILED;
00204     }
00205 
00206     return CONNECT_SUCCESS;
00207 }
00208 
00210 // ConsistentHash
00211 //
00212 
00213 bool MemCacheClient::ConsistentHash::operator<(const MemCacheClient::ConsistentHash & rhs) const 
00214 { 
00215     if (mHash != rhs.mHash) {
00216         return mHash < rhs.mHash; 
00217     }
00218 
00219     // in case we get multiple servers with the same hash, compare the actual server
00220     // addresses to get a consistent ordering
00221     if (mServer != rhs.mServer) {
00222         return strcmp(mServer->GetAddress(), rhs.mServer->GetAddress()) < 0;
00223     }
00224 
00225     return mEntry == rhs.mEntry;
00226 }
00227 
00229 struct MemCacheClient::ConsistentHash::MatchServer
00230 {
00232     MemCacheClient::Server * mServer; 
00233 
00237     MatchServer(MemCacheClient::Server * aServer) : mServer(aServer) { }
00238 
00242     bool operator()(const ConsistentHash & rhs) const { return rhs.mServer == mServer; }
00243 };
00244 
00246 // MemCacheClient
00247 
00248 MemCacheClient::MemCacheClient()
00249     : mTrace("MEMCACHE")
00250     , mTimeoutMs(1000)
00251     , mRetryMs(300 * 1000)
00252 {
00253 }
00254 
00255 MemCacheClient::~MemCacheClient()
00256 {
00257     ClearServers();
00258 }
00259 
00260 void
00261 MemCacheClient::ClearServers()
00262 {
00263     for (size_t n = 0; n < mServer.size(); ++n) {
00264         delete mServer[n];
00265     }
00266     mServer.clear();
00267 }
00268 
00269 const char * 
00270 MemCacheClient::ConvertResult(
00271     MCResult aResult
00272     ) 
00273 {
00274     switch (aResult) {
00275     case MCERR_OK:        return "MCERR_OK";
00276     case MCERR_NOREPLY:   return "MCERR_NOREPLY";
00277     case MCERR_NOTSTORED: return "MCERR_NOTSTORED";
00278     case MCERR_NOTFOUND:  return "MCERR_NOTFOUND";
00279     case MCERR_NOSERVER:  return "MCERR_NOSERVER";
00280     default:              return "(unknown)";
00281     }
00282 }
00283 
00284 bool 
00285 MemCacheClient::AddServer(
00286     const char *    aServerAddress,
00287     const char *    aServerName,
00288     unsigned        aServices
00289     )
00290 {
00291     if (!aServerName) {
00292         aServerName = aServerAddress;
00293     }
00294 
00295     // if we the server address is valid then we allow the server 
00296     // to be added. All servers being added are assumed to be available
00297     // or to be soon made available. 
00298     Server * pServer = new Server(mTrace);
00299     if (!pServer->Set(aServerAddress)) {
00300         mTrace.Trace(CLERROR, "Ignoring invalid server: %s (%s)", 
00301             aServerAddress, aServerName);
00302         delete pServer;
00303         return false;
00304     }
00305     for (size_t n = 0; n < mServer.size(); ++n) {
00306         if (*pServer == *mServer[n]) {
00307             mTrace.Trace(CLERROR, "Ignoring duplicate server: %s (%s)", 
00308                 aServerAddress, aServerName);
00309             return true; // already have it
00310         }
00311     }
00312     mServer.push_back(pServer);
00313 
00314     // for each salt we generate a string hash for the consistent hash 
00315     // table. To ensure stability of the hashing for multiple servers, 
00316     // we want to have a number of entries for each server. 
00317     static const char * rgpSalt[] = {
00318         "{DEA60AAB-CFF9-4a20-A799-4E5E93369656}",
00319         "{C05167CC-57DA-40f2-9EB8-18F65E56FD21}",
00320         "{57939537-0966-49e7-B675-ACE63246BFA5}",
00321         "{F0C8BE5C-A0F1-478f-BC45-28D42AF0CA1E}"
00322     };
00323 
00324     string_t sKey;
00325     ConsistentHash entry(0, pServer, aServices, 0);
00326     for (size_t n = 0; n < sizeof(rgpSalt)/sizeof(rgpSalt[0]); ++n) {
00327         sKey  = pServer->GetAddress();
00328         sKey += rgpSalt[n];
00329         entry.mEntry++;
00330         entry.mHash = CreateKeyHash(sKey.data());
00331         mServerHash.push_back(entry);
00332     }
00333 
00334     // sort the vector so that we can binary search it
00335     std::sort(mServerHash.begin(), mServerHash.end());
00336 
00337     mTrace.Trace(CLINFO, "Adding server: %s (%s:%u), services: 0x%x",
00338         aServerAddress, aServerName, pServer->GetPort(), aServices);
00339     return true;
00340 }
00341 
00342 void
00343 MemCacheClient::DumpTables()
00344 {
00345     // we need this information to ensure that different servers are
00346     // using the same consistent hashing tables.
00347     if (!mTrace.IsThisModuleTracing(CLDEBUG)) {
00348         return;
00349     }
00350 
00351     std::string verify;
00352     char buf[200];
00353     mTrace.Trace(CLDEBUG, "Consistent Hash Server Ring (%u entries):", mServerHash.size());
00354     for (size_t n = 0; n < mServerHash.size(); ++n) {
00355         const ConsistentHash & server = mServerHash[n];
00356         mTrace.Trace(CLDEBUG, "%2u: %08lx = %s (services: 0x%x, entry: %d)", 
00357             n, server.mHash, server.mServer->GetAddress(), server.mServices, server.mEntry);
00358         snprintf(buf, sizeof(buf), "%s>%d>%x>%lx>", server.mServer->GetAddress(), 
00359             server.mEntry, server.mServices, server.mHash);
00360         verify += buf;
00361     }
00362 
00363     mTrace.Trace(CLDEBUG, "Data verification code: %lx", CreateKeyHash(verify.c_str()));
00364 }
00365 
00366 bool 
00367 MemCacheClient::DelServer(
00368     const char * aServer
00369     )
00370 {
00371     Server test(mTrace);
00372     if (test.Set(aServer)) {
00373         std::vector<Server*>::iterator i = mServer.begin();
00374         for (; i != mServer.end(); ++i) {
00375             Server * pServer = *i;
00376             if (test != *pServer) continue;
00377 
00378             delete pServer;
00379             mServer.erase(i);
00380             ConsistentHash::MatchServer server(pServer);
00381             mServerHash.erase(
00382                 std::partition(mServerHash.begin(), mServerHash.end(), server), 
00383                 mServerHash.end());
00384             std::sort(mServerHash.begin(), mServerHash.end());
00385             return true;
00386         }
00387     }
00388 
00389     // not found
00390     return false;
00391 }
00392 
00393 void 
00394 MemCacheClient::GetServers(
00395     std::vector<string_t> & aServers
00396     )
00397 {
00398     string_t address;
00399     aServers.clear();
00400     aServers.reserve(mServer.size());
00401     for (size_t n = 0; n < mServer.size(); ++n) {
00402         address = mServer[n]->GetAddress();
00403         aServers.push_back(address);
00404     }
00405 }
00406 
00407 void 
00408 MemCacheClient::SetTimeout(
00409     size_t aTimeoutMs
00410     )
00411 {
00412     mTimeoutMs = aTimeoutMs;
00413 }
00414 
00415 void 
00416 MemCacheClient::SetRetryPeriod(
00417     size_t aRetryMs
00418     )
00419 {
00420     mRetryMs = aRetryMs;
00421 }
00422 
00423 unsigned long 
00424 MemCacheClient::CreateKeyHash(
00425     const char * aKey
00426     )
00427 {
00428     const size_t LONG_COUNT = SHA_DIGEST_LENGTH / sizeof(unsigned long);
00429     
00430     union {
00431         unsigned char as_char[SHA_DIGEST_LENGTH];
00432         unsigned long as_long[LONG_COUNT];
00433     } output;
00434 
00435     CR_ASSERT(sizeof(output.as_char) == SHA_DIGEST_LENGTH);
00436     CR_ASSERT(sizeof(output.as_long) == SHA_DIGEST_LENGTH);
00437 
00438     SHA1((const unsigned char *) aKey, (unsigned long) strlen(aKey), output.as_char);
00439     return output.as_long[LONG_COUNT-1];
00440 }
00441 
00442 MemCacheClient::Server *
00443 MemCacheClient::FindServer(
00444     const string_t & aKey,
00445     unsigned         aService
00446     )
00447 {
00448 #ifdef CROSSBASE_API
00449     // in our private usage of this, the service must never be 0
00450     if (aService == 0) {
00451         mTrace.Trace(CLERROR, "FindServer: no service requested, supplied cache server may not be appropriate!!!");
00452         CR_ASSERT(!"FindServer: no service requested, supplied cache server may not be appropriate!!!");
00453     }
00454 #endif
00455 
00456     // probably need some servers for this
00457     if (mServerHash.empty()) {
00458         //mTrace.Trace(CLDEBUG, "FindServer: server hash is empty");
00459         return NULL;
00460     }
00461 
00462     // find the next largest consistent hash value above this key hash
00463     ConsistentHash hash(CreateKeyHash(aKey.data()), NULL, 0, 0);
00464     std::vector<ConsistentHash>::iterator iBegin = mServerHash.begin();
00465     std::vector<ConsistentHash>::iterator iEnd = mServerHash.end();
00466     std::vector<ConsistentHash>::iterator iCurr = std::lower_bound(iBegin, iEnd, hash);
00467     if (iCurr == iEnd) iCurr = iBegin;
00468 
00469     // now find the next server that handles this service
00470     if (aService != 0) {
00471         //int nSkipped = 0;
00472         std::vector<ConsistentHash>::iterator iStart = iCurr;
00473         while (!iCurr->services(aService)) {
00474             //++nSkipped;
00475             ++iCurr; 
00476             if (iCurr == iEnd) iCurr = iBegin;
00477             if (iCurr == iStart) {
00478                 mTrace.Trace(CLDEBUG, "FindServer: no server for required service: %u", aService);
00479                 return NULL;
00480             }
00481         }
00482         //if (nSkipped > 0) mTrace.Trace(CLDEBUG, "skipped %d servers for service: %u", nSkipped, aService);
00483     }
00484 
00485     // ensure that this server is connected 
00486     Server * pServer = iCurr->mServer;
00487     Server::ConnectResult rc = pServer->Connect(mTimeoutMs, mRetryMs);
00488     switch (rc) {
00489     case Server::CONNECT_SUCCESS:
00490         //mTrace.Trace(CLDEBUG, "FindServer: using server %s", pServer->GetAddress());
00491         return pServer;
00492     case Server::CONNECT_WAITING:
00493         return NULL;
00494     default:
00495     case Server::CONNECT_FAILED:
00496         //mTrace.Trace(CLDEBUG, "FindServer: failed to connect to server %s", pServer->GetAddress());
00497         return NULL;
00498     }
00499 }
00500 
00502 struct MemCacheClient::MemRequest::Sort 
00503 { 
00509     bool operator()(const MemRequest * pl, const MemRequest * pr) const {
00510         return pl->mServer < pr->mServer; // any order is fine
00511     }
00512 }; 
00513 
00514 int 
00515 MemCacheClient::Combine(
00516     const char *    aType,
00517     MemRequest *    aItem, 
00518     int             aCount
00519     )
00520 {
00521     if (aCount < 1) {
00522         mTrace.Trace(CLDEBUG, "%s: ignoring request for %d items",
00523             aType, aCount);
00524         return 0;
00525     }
00526     CR_ASSERT(*aType == 'g' || *aType == 'd'); // get, gets, del
00527 
00528     MemRequest * rgpItem[MAX_REQUESTS] = { NULL };
00529     if (aCount > MAX_REQUESTS) {
00530         mTrace.Trace(CLDEBUG, "%s: ignoring request for all %d items (too many)", 
00531             aType, aCount);
00532         return -1; // invalid args
00533     }
00534 
00535     // initialize and find all of the servers for these items
00536     int nItemCount = 0;
00537     for (int n = 0; n < aCount; ++n) {
00538         // ensure that the key doesn't have a space in it
00539         CR_ASSERT(NULL == strchr(aItem[n].mKey.data(), ' '));
00540         aItem[n].mServer = FindServer(aItem[n].mKey, aItem[n].mService);
00541         aItem[n].mData.SetEmpty();
00542         if (aItem[n].mServer) {
00543             rgpItem[nItemCount++] = &aItem[n];
00544         }
00545         else {
00546             aItem[n].mResult = MCERR_NOSERVER;
00547         }
00548     }
00549     if (nItemCount == 0) {
00550         mTrace.Trace(CLDEBUG, "%s: ignoring request for all %d items (no servers available)", 
00551             aType, aCount);
00552         return 0;
00553     }
00554 
00555     // sort all requests into server order
00556     const static MemRequest::Sort sortOnServer = MemRequest::Sort();
00557     std::sort(&rgpItem[0], &rgpItem[nItemCount], sortOnServer);
00558 
00559     // send all requests
00560     char szBuf[50];
00561     int nItem = 0, nNext;
00562     string_t sRequest, sTemp;
00563     while (nItem < nItemCount) {
00564         for (nNext = nItem; nNext < nItemCount; ++nNext) {
00565             if (rgpItem[nItem]->mServer != rgpItem[nNext]->mServer) break;
00566             CR_ASSERT(*aType == 'g' || *aType == 'd');
00567             rgpItem[nNext]->mData.SetEmpty();
00568 
00569             // create get request for all keys on this server
00570             if (*aType == 'g') {
00571                 if (nNext == nItem) sRequest = "get";
00572                 else sRequest.resize(sRequest.length() - 2);
00573                 sRequest += ' ';
00574                 sRequest += rgpItem[nNext]->mKey;
00575                 sRequest += "\r\n";
00576                 rgpItem[nNext]->mResult = MCERR_NOTFOUND;
00577             }
00578             // create del request for all keys on this server
00579             else if (*aType == 'd') {
00580                 // delete <key> [<time>] [noreply]\r\n
00581                 sRequest += "delete ";
00582                 sRequest += rgpItem[nNext]->mKey;
00583                 sRequest += ' ';
00584                 snprintf(szBuf, sizeof(szBuf), "%ld", (long) rgpItem[nNext]->mExpiry);
00585                 sRequest += szBuf;
00586                 if (rgpItem[nNext]->mResult == MCERR_NOREPLY) {
00587                     sRequest += " noreply";
00588                 }
00589                 sRequest += "\r\n";
00590                 if (rgpItem[nNext]->mResult != MCERR_NOREPLY) {
00591                     rgpItem[nNext]->mResult = MCERR_NOTFOUND;
00592                 }
00593             }
00594         }
00595 
00596         // send the request. any socket error causes the server connection 
00597         // to be dropped, so we return errors for all requests using that server.
00598         try {
00599             rgpItem[nItem]->mServer->SendBytes(
00600                 sRequest.data(), sRequest.length());
00601         }
00602         catch (const Socket::Exception & e) {
00603             mTrace.Trace(CLINFO, "%s: request error '%s' at %s, marking requests as NOSERVER",
00604                 aType, e.mDetail, rgpItem[nItem]->mServer->GetAddress());
00605             for (int n = nItem; n < nNext; ++n) {
00606                 rgpItem[n]->mServer = NULL;
00607                 rgpItem[n]->mResult = MCERR_NOSERVER;
00608             }
00609         }
00610         nItem = nNext;
00611     }
00612 
00613     // receive responses from all servers
00614     int nResponses = 0;
00615     for (nItem = 0; nItem < nItemCount; nItem = nNext) {
00616         // find the end of this server
00617         if (!rgpItem[nItem]->mServer) { nNext = nItem + 1; continue; }
00618         for (nNext = nItem + 1; nNext < nItemCount; ++nNext) {
00619             if (rgpItem[nItem]->mServer != rgpItem[nNext]->mServer) break;
00620         }
00621 
00622         // receive the responses. any socket error causes the server connection 
00623         // to be dropped, so we return errors for all requests using that server.
00624         try {
00625             if (*aType == 'g') {
00626                 nResponses += HandleGetResponse(
00627                     rgpItem[nItem]->mServer, 
00628                     &rgpItem[nItem], &rgpItem[nNext]);
00629             }
00630             else if (*aType == 'd') {
00631                 nResponses += HandleDelResponse(
00632                     rgpItem[nItem]->mServer, 
00633                     &rgpItem[nItem], &rgpItem[nNext]);
00634             }
00635         }
00636         catch (const Socket::Exception & e) {
00637             mTrace.Trace(CLINFO, "%s: response error '%s' at %s, marking requests as NOSERVER",
00638                 aType, e.mDetail, rgpItem[nItem]->mServer->GetAddress());
00639             rgpItem[nItem]->mServer->Disconnect();
00640             for (int n = nNext - 1; n >= nItem; --n) {
00641                 if (rgpItem[nItem]->mServer != rgpItem[n]->mServer) continue;
00642                 rgpItem[n]->mServer = NULL;
00643                 rgpItem[n]->mResult = MCERR_NOSERVER;
00644             }
00645         }
00646     }
00647 
00648     mTrace.Trace(CLDEBUG, "%s: received %d responses to %d requests",
00649         aType, nResponses, aCount);
00650     return nResponses;
00651 }
00652 
00653 int 
00654 MemCacheClient::HandleGetResponse(
00655     Server *        aServer, 
00656     MemRequest **   aBegin, 
00657     MemRequest **   aEnd
00658     )
00659 {
00660     int nFound = 0;
00661 
00662     std::string sValue;
00663     for (;;) {
00664         // get the value
00665         aServer->ReceiveLine(sValue, false);
00666         if (sValue == "END\r\n") break;
00667 
00668         // if it isn't a value then we are in a bad state
00669         if (0 != strncmp(sValue.data(), "VALUE ", 6)) {
00670             throw Socket::Exception(Socket::ERR_OTHER, 0, "bad get response at VALUE");
00671         }
00672 
00673         // extract the key
00674         int n = (int) sValue.find(' ', 6);
00675         if (n < 1) throw Socket::Exception(Socket::ERR_OTHER, 0, "bad get response at key");
00676         std::string sKey(sValue, 6, n - 6);
00677 
00678         // extract the flags
00679         const char * pVal = sValue.data() + n + 1;
00680         unsigned nFlags = (unsigned) strtoul(pVal, (char**) &pVal, 10);
00681         if (*pVal++ != ' ') throw Socket::Exception(Socket::ERR_OTHER, 0, "bad get response at flags");
00682 
00683         // extract the size
00684         unsigned nBytes = (unsigned) strtoul(pVal, (char**) &pVal, 10);
00685         if (*pVal != ' ' && *pVal != '\r') throw Socket::Exception(Socket::ERR_OTHER, 0, "bad get response at size");
00686 
00687         // find this key in the array
00688         MemRequest * pItem = NULL; 
00689         for (MemRequest ** p = aBegin; p < aEnd; ++p) {
00690             if ((*p)->mKey == sKey.data()) { pItem = *p; break; }
00691         }
00692         if (!pItem) { // key not found, discard the response
00693             aServer->DiscardBytes(nBytes + 2); // +2 == include final "\r\n"
00694             continue;
00695         }
00696         pItem->mFlags = nFlags;
00697 
00698         // extract the cas
00699         if (*pVal == ' ') {
00700             char * last = NULL;
00701             pItem->mCas = strtoull(++pVal, &last, 10);
00702             if (*last != '\r') throw Socket::Exception(Socket::ERR_OTHER, 0, "bad get response at CAS");
00703         }
00704 
00705         // receive the data
00706         while (nBytes > 0) {
00707             char * pBuf = pItem->mData.GetWriteBuffer(nBytes);
00708             int nReceived = aServer->GetBytes(pBuf, nBytes);
00709             pItem->mData.CommitWriteBytes(nReceived);
00710             nBytes -= nReceived;
00711         }
00712         pItem->mResult = MCERR_OK;
00713 
00714         // discard the trailing "\r\n"
00715         if ('\r' != aServer->GetByte() ||
00716             '\n' != aServer->GetByte())
00717         {
00718             throw Socket::Exception(Socket::ERR_OTHER, 0, "bad get response at trail");
00719         }
00720 
00721         ++nFound;
00722     }
00723 
00724     return nFound;
00725 }
00726 
00727 int 
00728 MemCacheClient::HandleDelResponse(
00729     Server *        aServer, 
00730     MemRequest **   aBegin, 
00731     MemRequest **   aEnd
00732     )
00733 {
00734     std::string sValue;
00735     int nResponses = 0;
00736     for (MemRequest ** p = aBegin; p < aEnd; ++p) {
00737         MemRequest * pItem = *p; 
00738 
00739         // no response for this entry
00740         if (pItem->mResult == MCERR_NOREPLY) continue;
00741 
00742         // get the value
00743         aServer->ReceiveLine(sValue, false);
00744 
00745         // success
00746         if (sValue == "DELETED\r\n") {
00747             pItem->mResult = MCERR_OK;
00748             ++nResponses;
00749             continue;
00750         }
00751 
00752         // the item with this key was not found
00753         if (sValue == "NOT_FOUND\r\n") {
00754             pItem->mResult = MCERR_NOTFOUND;
00755             ++nResponses;
00756             continue;
00757         }
00758 
00759         aServer->Disconnect();
00760         throw Socket::Exception(Socket::ERR_OTHER, 0, "bad del response");
00761     }
00762 
00763     return nResponses;
00764 }
00765 
00766 MCResult 
00767 MemCacheClient::IncDec(
00768     const char *    aType, 
00769     unsigned        aService,
00770     const char *    aKey, 
00771     uint64_t *      aNewValue,
00772     uint64_t        aDiff,
00773     bool            aWantReply
00774     )
00775 {
00776     string_t key(aKey);
00777     Server * pServer = FindServer(key, aService);
00778     if (!pServer) return MCERR_NOSERVER;
00779 
00780     char szBuf[50];
00781     string_t sRequest(aType);
00782     sRequest += ' ';
00783     sRequest += aKey;
00784     snprintf(szBuf, sizeof(szBuf), " %" PRIu64, aDiff);
00785     sRequest += szBuf;
00786     if (!aWantReply) {
00787         sRequest += " noreply";
00788     }
00789     sRequest += "\r\n";
00790 
00791     try {
00792         pServer->SendBytes(sRequest.data(), sRequest.length());
00793 
00794         if (!aWantReply) {
00795             return MCERR_NOREPLY;
00796         }
00797 
00798         string_t sValue;
00799         sValue = pServer->GetByte();
00800         while (sValue[sValue.length()-1] != '\n') {
00801             sValue += pServer->GetByte();
00802         }
00803 
00804         if (sValue == "NOT_FOUND\r\n") {
00805             return MCERR_NOTFOUND;
00806         }
00807 
00808         if (aNewValue) {
00809             *aNewValue = strtoull(sValue.data(), NULL, 10);
00810         }
00811         return MCERR_OK;
00812     }
00813     catch (const Socket::Exception & e) {
00814         mTrace.Trace(CLINFO, "IncDec: error '%s' at %s, marking request as NOSERVER",
00815             e.mDetail, pServer->GetAddress());
00816         pServer->Disconnect();
00817         return MCERR_NOSERVER;
00818     }
00819 }
00820 
00821 int 
00822 MemCacheClient::Store(
00823     const char *    aType,
00824     MemRequest *    aItem, 
00825     int             aCount
00826     )
00827 {
00828     if (aCount < 1) {
00829         mTrace.Trace(CLDEBUG, "Store: ignoring request for %d items", aCount);
00830         return 0;
00831     }
00832 
00833     // initialize and find all of the servers for these items
00834     int nItemCount = 0;
00835     for (int n = 0; n < aCount; ++n) {
00836         // ensure that the key doesn't have a space in it
00837         CR_ASSERT(NULL == strchr(aItem[n].mKey.data(), ' '));
00838         aItem[n].mServer = FindServer(aItem[n].mKey, aItem[n].mService);
00839         if (aItem[n].mServer) {
00840             ++nItemCount;
00841         }
00842         else {
00843             aItem[n].mResult = MCERR_NOSERVER;
00844         }
00845     }
00846     if (nItemCount == 0) {
00847         mTrace.Trace(CLDEBUG, "Store: ignoring request for all %d items (no servers available)", 
00848             aCount);
00849         return 0;
00850     }
00851 
00852     char szBuf[50];
00853     int nResponses = 0;
00854     string_t sRequest;
00855     for (int n = 0; n < aCount; ++n) {
00856         if (!aItem[n].mServer) continue;
00857 
00858         // <command name> <key> <flags> <exptime> <bytes> [noreply]\r\n
00859         sRequest  = aType;
00860         sRequest += ' ';
00861         sRequest += aItem[n].mKey;
00862         snprintf(szBuf, sizeof(szBuf), " %u %ld %u", 
00863             aItem[n].mFlags, (long) aItem[n].mExpiry, 
00864             (unsigned)aItem[n].mData.GetReadSize());
00865         sRequest += szBuf;
00866         if (*aType == 'c') { // cas
00867             snprintf(szBuf, sizeof(szBuf), " %" PRIu64, aItem[n].mCas);
00868             sRequest += szBuf;
00869         }
00870         if (aItem[n].mResult == MCERR_NOREPLY) {
00871             sRequest += " noreply";
00872         }
00873         sRequest += "\r\n";
00874 
00875         // send the request. any socket error causes the server connection 
00876         // to be dropped, so we return errors for all requests using that server.
00877         try {
00878             aItem[n].mServer->SendBytes(
00879                 sRequest.data(), sRequest.length());
00880             aItem[n].mServer->SendBytes(
00881                 aItem[n].mData.GetReadBuffer(), 
00882                 aItem[n].mData.GetReadSize());
00883             aItem[n].mServer->SendBytes("\r\n", 2);
00884 
00885             // done with these read bytes
00886             aItem[n].mData.CommitReadBytes(
00887                 aItem[n].mData.GetReadSize());
00888 
00889             // if no reply is required then move on to the next request
00890             if (aItem[n].mResult == MCERR_NOREPLY) {
00891                 continue;
00892             }
00893 
00894             // handle this response
00895             HandleStoreResponse(aItem[n].mServer, aItem[n]);
00896             ++nResponses;
00897         }
00898         catch (const Socket::Exception & e) {
00899             mTrace.Trace(CLINFO, "Store: error '%s' at %s, marking requests as NOSERVER",
00900                 e.mDetail, aItem[n].mServer->GetAddress());
00901             for (int i = aCount - 1; i >= n; --i) {
00902                 if (aItem[n].mServer != aItem[i].mServer) continue;
00903                 aItem[i].mServer = NULL;
00904                 aItem[i].mResult = MCERR_NOSERVER;
00905             }
00906             continue;
00907         }
00908     }
00909 
00910     return nResponses;
00911 }
00912 
00913 void
00914 MemCacheClient::HandleStoreResponse(
00915     Server *        aServer, 
00916     MemRequest &    aItem
00917     )
00918 {
00919     // get the value
00920     std::string sValue;
00921     aServer->ReceiveLine(sValue, false);
00922 
00923     // success
00924     if (sValue == "STORED\r\n") {
00925         aItem.mResult = MCERR_OK;
00926         return;
00927     }
00928 
00929     // data was not stored, but not because of an error. 
00930     // This normally means that either that the condition for 
00931     // an "add" or a "replace" command wasn't met, or that the
00932     // item is in a delete queue.
00933     if (sValue == "NOT_STORED\r\n") {
00934         aItem.mResult = MCERR_NOTSTORED;
00935         return;
00936     }
00937 
00938     // data was not stored, perhaps the key was too long?
00939     if (sValue == "ERROR\r\n") {
00940         aItem.mResult = MCERR_NOTSTORED;
00941         return;
00942     }
00943 
00944     // unknown response, connection may be bad
00945     aServer->Disconnect();
00946     throw Socket::Exception(Socket::ERR_OTHER, 0, "bad store response");
00947 }
00948 
00949 int
00950 MemCacheClient::FlushAll(
00951     const char *    aServer, 
00952     int             aExpiry
00953     )
00954 {
00955     char szRequest[50];
00956     snprintf(szRequest, sizeof(szRequest), 
00957         "flush_all %u\r\n", aExpiry);
00958 
00959     Server test(mTrace);
00960     if (aServer && !test.Set(aServer)) {
00961         return false;
00962     }
00963 
00964     int nSuccess = 0;
00965     for (size_t n = 0; n < mServer.size(); ++n) {
00966         Server * pServer = mServer[n];
00967         if (aServer && *pServer != test) continue;
00968     
00969         // ensure that we are connected
00970         if (pServer->Connect(mTimeoutMs, mRetryMs) != Server::CONNECT_SUCCESS) {
00971             continue;
00972         }
00973 
00974         try {
00975             // request
00976             pServer->SendBytes(szRequest, strlen(szRequest));
00977 
00978             // response
00979             string_t sValue;
00980             sValue = pServer->GetByte();
00981             while (sValue[sValue.length()-1] != '\n') {
00982                 sValue += pServer->GetByte();
00983             }
00984             if (sValue == "OK\r\n") {
00985                 // done
00986                 ++nSuccess;
00987             }
00988             else {
00989                 // unknown response, connection may be bad
00990                 pServer->Disconnect();
00991             }
00992         }
00993         catch (const Socket::Exception &) {
00994             mTrace.Trace(CLINFO, "socket error, ignoring flush request");
00995             // data error
00996         }
00997     }
00998 
00999     return nSuccess;
01000 }
01001 
01002 #ifdef CROSSBASE_API
01003 END_CL_NAMESPACE
01004 #endif