Cuberite
A lightweight, fast and extensible game server for Minecraft
BlockingSslClientSocket.cpp
Go to the documentation of this file.
1 
2 // BlockingSslClientSocket.cpp
3 
4 // Implements the cBlockingSslClientSocket class representing a blocking TCP socket with client SSL encryption over it
5 
6 #include "Globals.h"
8 
9 
10 
11 
12 
14 // cBlockingSslClientSocketConnectCallbacks:
15 
18 {
21 
22  virtual void OnConnected(cTCPLink & a_Link) override
23  {
25  }
26 
27  virtual void OnError(int a_ErrorCode, const AString & a_ErrorMsg) override
28  {
29  m_Socket.OnConnectError(a_ErrorMsg);
30  }
31 
32 public:
34  m_Socket(a_Socket)
35  {
36  }
37 };
38 
39 
40 
41 
42 
44 // cBlockingSslClientSocketLinkCallbacks:
45 
48 {
50 
51  virtual void OnLinkCreated(cTCPLinkPtr a_Link) override
52  {
53  m_Socket.SetLink(a_Link);
54  }
55 
56 
57  virtual void OnReceivedData(const char * a_Data, size_t a_Length) override
58  {
59  m_Socket.OnReceivedData(a_Data, a_Length);
60  }
61 
62 
63  virtual void OnRemoteClosed(void) override
64  {
66  }
67 
68 
69  virtual void OnError(int a_ErrorCode, const AString & a_ErrorMsg) override
70  {
72  }
73 
74 public:
75 
77  m_Socket(a_Socket)
78  {
79  }
80 };
81 
82 
83 
84 
85 
87 // cBlockingSslClientSocket:
88 
90  m_Ssl(*this),
91  m_IsConnected(false)
92 {
93  // Nothing needed yet
94 }
95 
96 
97 
98 
99 
100 bool cBlockingSslClientSocket::Connect(const AString & a_ServerName, UInt16 a_Port)
101 {
102  // If already connected, report an error:
103  if (m_IsConnected)
104  {
105  // TODO: Handle this better - if connected to the same server and port, and the socket is alive, return success
106  m_LastErrorText = "Already connected";
107  return false;
108  }
109 
110  // Connect the underlying socket:
111  m_ServerName = a_ServerName;
112  if (!cNetwork::Connect(a_ServerName, a_Port,
113  std::make_shared<cBlockingSslClientSocketConnectCallbacks>(*this),
114  std::make_shared<cBlockingSslClientSocketLinkCallbacks>(*this))
115  )
116  {
117  return false;
118  }
119 
120  // Wait for the connection to succeed or fail:
121  m_Event.Wait();
122  if (!m_IsConnected)
123  {
124  return false;
125  }
126 
127  // Initialize the SSL:
128  int ret = 0;
129  if (m_Config != nullptr)
130  {
131  ret = m_Ssl.Initialize(m_Config);
132  }
133  else
134  {
135  ret = m_Ssl.Initialize(true);
136  }
137 
138  if (ret != 0)
139  {
140  m_LastErrorText = fmt::format(FMT_STRING("SSL initialization failed: -0x{:x}"), -ret);
141  return false;
142  }
143 
144  // If we have been assigned a trusted CA root cert store, push it into the SSL context:
145  if (!m_ExpectedPeerName.empty())
146  {
147  m_Ssl.SetExpectedPeerName(m_ExpectedPeerName);
148  }
149 
150  ret = m_Ssl.Handshake();
151  if (ret != 0)
152  {
153  m_LastErrorText = fmt::format(FMT_STRING("SSL handshake failed: -0x{:x}"), -ret);
154  return false;
155  }
156 
157  return true;
158 }
159 
160 
161 
162 
163 
165 {
166  ASSERT(!m_IsConnected); // Must be called before connect
167 
168  // Warn if used multiple times, but don't signal an error:
169  if (!m_ExpectedPeerName.empty())
170  {
171  LOGWARNING(
172  "SSL: Trying to set multiple expected peer names, only the last one will be used. %s overwriting the previous %s",
173  a_ExpectedPeerName, m_ExpectedPeerName
174  );
175  }
176 
177  m_ExpectedPeerName = std::move(a_ExpectedPeerName);
178 }
179 
180 
181 
182 
183 
184 void cBlockingSslClientSocket::SetSslConfig(std::shared_ptr<const cSslConfig> a_Config)
185 {
186  ASSERT(!m_IsConnected); // Must be called before connect
187 
188  // Warn if used multiple times, but don't signal an error:
189  if (m_Config != nullptr)
190  {
191  LOGWARNING("SSL: Trying to set multiple configurations, only the last one will be used.");
192  }
193 
194  m_Config = std::move(a_Config);
195 }
196 
197 
198 
199 
200 
201 bool cBlockingSslClientSocket::Send(const void * a_Data, size_t a_NumBytes)
202 {
203  if (!m_IsConnected)
204  {
205  m_LastErrorText = "Socket is closed";
206  return false;
207  }
208 
209  // Keep sending the data until all of it is sent:
210  const char * Data = static_cast<const char *>(a_Data);
211  size_t NumBytes = a_NumBytes;
212  for (;;)
213  {
214  int res = m_Ssl.WritePlain(Data, a_NumBytes);
215  if (res < 0)
216  {
217  ASSERT(res != MBEDTLS_ERR_SSL_WANT_READ); // This should never happen with callback-based SSL
218  ASSERT(res != MBEDTLS_ERR_SSL_WANT_WRITE); // This should never happen with callback-based SSL
219  m_LastErrorText = fmt::format(FMT_STRING("Data cannot be written to SSL context: -0x{:x}"), -res);
220  return false;
221  }
222  else
223  {
224  Data += res;
225  NumBytes -= static_cast<size_t>(res);
226  if (NumBytes == 0)
227  {
228  return true;
229  }
230  }
231  }
232 }
233 
234 
235 
236 
237 
238 int cBlockingSslClientSocket::Receive(void * a_Data, size_t a_MaxBytes)
239 {
240  // Even if m_IsConnected is false (socket disconnected), the SSL context may have more data in the queue
241  int res = m_Ssl.ReadPlain(a_Data, a_MaxBytes);
242  if (res < 0)
243  {
244  m_LastErrorText = fmt::format(FMT_STRING("Data cannot be read from SSL context: -0x{:x}"), -res);
245  }
246  return res;
247 }
248 
249 
250 
251 
252 
254 {
255  // Ignore if not connected
256  if (!m_IsConnected)
257  {
258  return;
259  }
260 
261  m_Ssl.NotifyClose();
262  m_IsConnected = false;
263 
264  // Grab a copy of the socket so that we know it doesn't change under our hands:
265  auto socket = m_Socket;
266  if (socket != nullptr)
267  {
268  socket->Close();
269  }
270 
271  m_Socket.reset();
272 }
273 
274 
275 
276 
277 
278 int cBlockingSslClientSocket::ReceiveEncrypted(unsigned char * a_Buffer, size_t a_NumBytes)
279 {
280  // Wait for any incoming data, if there is none:
282  while (m_IsConnected && m_IncomingData.empty())
283  {
284  cCSUnlock Unlock(Lock);
285  m_Event.Wait();
286  }
287 
288  // If we got disconnected, report an error after processing all data:
289  if (!m_IsConnected && m_IncomingData.empty())
290  {
292  }
293 
294  // Copy the data from the incoming buffer into the specified space:
295  size_t NumToCopy = std::min(a_NumBytes, m_IncomingData.size());
296  memcpy(a_Buffer, m_IncomingData.data(), NumToCopy);
297  m_IncomingData.erase(0, NumToCopy);
298  return static_cast<int>(NumToCopy);
299 }
300 
301 
302 
303 
304 
305 int cBlockingSslClientSocket::SendEncrypted(const unsigned char * a_Buffer, size_t a_NumBytes)
306 {
307  cTCPLinkPtr Socket(m_Socket); // Make a copy so that multiple threads don't race on deleting the socket.
308  if (Socket == nullptr)
309  {
311  }
312  if (!Socket->Send(a_Buffer, a_NumBytes))
313  {
314  // mbedTLS's net routines distinguish between connection reset and general failure, we don't need to
316  }
317  return static_cast<int>(a_NumBytes);
318 }
319 
320 
321 
322 
323 
325 {
326  m_IsConnected = true;
327  m_Event.Set();
328 }
329 
330 
331 
332 
333 
335 {
336  LOG("Cannot connect to %s: \"%s\"", m_ServerName.c_str(), a_ErrorMsg.c_str());
337  m_Event.Set();
338 }
339 
340 
341 
342 
343 
344 void cBlockingSslClientSocket::OnReceivedData(const char * a_Data, size_t a_Size)
345 {
346  {
348  m_IncomingData.append(a_Data, a_Size);
349  }
350  m_Event.Set();
351 }
352 
353 
354 
355 
356 
358 {
359  m_Socket = std::move(a_Link);
360 }
361 
362 
363 
364 
365 
367 {
368  m_IsConnected = false;
369  m_Socket.reset();
370  m_Event.Set();
371 }
372 
373 
374 
375 
#define ASSERT(x)
Definition: Globals.h:276
unsigned short UInt16
Definition: Globals.h:158
void LOGWARNING(std::string_view a_Format, const Args &... args)
Definition: LoggerSimple.h:67
void LOG(std::string_view a_Format, const Args &... args)
Definition: LoggerSimple.h:55
#define MBEDTLS_ERR_NET_RECV_FAILED
Reading information from the socket failed.
Definition: ErrorCodes.h:10
#define MBEDTLS_ERR_NET_SEND_FAILED
Sending information through the socket failed.
Definition: ErrorCodes.h:11
std::shared_ptr< cTCPLink > cTCPLinkPtr
Definition: Network.h:25
std::string AString
Definition: StringUtils.h:11
virtual void OnConnected(cTCPLink &a_Link) override
Called when the Connect call succeeds.
cBlockingSslClientSocketConnectCallbacks(cBlockingSslClientSocket &a_Socket)
cBlockingSslClientSocket & m_Socket
The socket object that is using this instance of the callbacks.
virtual void OnError(int a_ErrorCode, const AString &a_ErrorMsg) override
Called when the Connect call fails.
virtual void OnRemoteClosed(void) override
Called when the remote end closes the connection.
virtual void OnLinkCreated(cTCPLinkPtr a_Link) override
Called when the cTCPLink for the connection is created.
virtual void OnError(int a_ErrorCode, const AString &a_ErrorMsg) override
Called when an error is detected on the connection.
virtual void OnReceivedData(const char *a_Data, size_t a_Length) override
Called when there's data incoming from the remote peer.
cBlockingSslClientSocketLinkCallbacks(cBlockingSslClientSocket &a_Socket)
cCallbackSslContext m_Ssl
The SSL context used for the socket.
void Disconnect(void)
Disconnects the connection gracefully, if possible.
bool Connect(const AString &a_ServerName, UInt16 a_Port)
Connects to the specified server and performs SSL handshake.
cTCPLinkPtr m_Socket
The underlying socket to the SSL server.
void SetSslConfig(std::shared_ptr< const cSslConfig > a_Config)
Set the config to be used by the SSL context.
void OnConnected(void)
Called when the connection is established successfully.
AString m_ServerName
The hostname to which the socket is connecting (stored for error reporting).
int Receive(void *a_Data, size_t a_MaxBytes)
Receives data from the connection.
virtual int ReceiveEncrypted(unsigned char *a_Buffer, size_t a_NumBytes) override
Called when mbedTLS wants to read encrypted data from the SSL peer.
virtual int SendEncrypted(const unsigned char *a_Buffer, size_t a_NumBytes) override
Called when mbedTLS wants to write encrypted data to the SSL peer.
void SetExpectedPeerName(AString a_ExpectedPeerName)
Sets the Expected peer name.
std::atomic< bool > m_IsConnected
Set to true if the connection was established successfully.
void OnConnectError(const AString &a_ErrorMsg)
Called when an error occurs while connecting the socket.
std::shared_ptr< const cSslConfig > m_Config
The configuration to be used by the SSL context.
cEvent m_Event
The object used to signal state changes in the socket (the cause of the blocking).
bool Send(const void *a_Data, size_t a_NumBytes)
Sends the specified data over the connection.
AString m_ExpectedPeerName
The expected SSL peer's name, if we are to verify the cert strictly.
cCriticalSection m_CSIncomingData
Protects m_IncomingData against multithreaded access.
AString m_IncomingData
Buffer for the data incoming on the network socket.
void OnReceivedData(const char *a_Data, size_t a_Size)
Called when there's incoming data from the socket.
void OnDisconnected(void)
Called when the link is disconnected, either gracefully or by an error.
void SetLink(cTCPLinkPtr a_Link)
Called when the link for the connection is created.
AString m_LastErrorText
Text of the last error that has occurred.
RAII for cCriticalSection - locks the CS on creation, unlocks on destruction.
Temporary RAII unlock for a cCSLock.
void Wait(void)
Waits until the event has been set.
Definition: Event.cpp:23
void Set(void)
Sets the event - releases one thread that has been waiting in Wait().
Definition: Event.cpp:52
Interface that provides the methods available on a single TCP connection.
Definition: Network.h:42
static bool Connect(const AString &a_Host, UInt16 a_Port, cConnectCallbacksPtr a_ConnectCallbacks, cTCPLink::cCallbacksPtr a_LinkCallbacks)
Queues a TCP connection to be made to the specified host.
Callbacks used for connecting to other servers as a client.
Definition: Network.h:237