Impala
Impala is the open source, native analytic database for Apache Hadoop.
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
TSaslTransport.cpp
Go to the documentation of this file.
1 // This file will be removed when the code is accepted into the Thrift library.
2 /*
3  * Licensed to the Apache Software Foundation (ASF) under one
4  * or more contributor license agreements. See the NOTICE file
5  * distributed with this work for additional information
6  * regarding copyright ownership. The ASF licenses this file
7  * to you under the Apache License, Version 2.0 (the
8  * "License"); you may not use this file except in compliance
9  * with the License. You may obtain a copy of the License at
10  *
11  * http://www.apache.org/licenses/LICENSE-2.0
12  *
13  * Unless required by applicable law or agreed to in writing,
14  * software distributed under the License is distributed on an
15  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
16  * KIND, either express or implied. See the License for the
17  * specific language governing permissions and limitations
18  * under the License.
19  */
20 
21 #include "config.h"
22 
23 #ifdef HAVE_SASL_SASL_H
#include <stdint.h>
#include <sstream>
#include <vector>
#include <boost/shared_ptr.hpp>
#include <boost/scoped_ptr.hpp>

#include <thrift/transport/TBufferTransports.h>

#include "common/names.h"
33 
// Initial capacity, in bytes, of the memory buffer that stages received data (32 KiB).
// The read path also resets the buffer back to this capacity once enough consumed
// bytes have accumulated at its front.
const int32_t DEFAULT_MEM_BUF_SIZE = 32768;
36 
37 namespace apache { namespace thrift { namespace transport {
38 
39  TSaslTransport::TSaslTransport(boost::shared_ptr<TTransport> transport)
40  : transport_(transport),
41  memBuf_(new TMemoryBuffer(DEFAULT_MEM_BUF_SIZE)),
42  shouldWrap_(false),
43  isClient_(false) {
44  }
45 
46  TSaslTransport::TSaslTransport(boost::shared_ptr<sasl::TSasl> saslClient,
47  boost::shared_ptr<TTransport> transport)
48  : transport_(transport),
49  memBuf_(new TMemoryBuffer()),
50  sasl_(saslClient),
51  shouldWrap_(false),
52  isClient_(true) {
53  }
54 
56  delete memBuf_;
57  }
58 
60  return transport_->isOpen();
61  }
62 
64  return (transport_->peek());
65  }
66 
68  return sasl_->getUsername();
69  }
70 
72  const uint8_t* payload, const uint32_t length, bool flush) {
73  uint8_t messageHeader[STATUS_BYTES + PAYLOAD_LENGTH_BYTES];
74  uint8_t dummy = 0;
75  if (payload == NULL) {
76  payload = &dummy;
77  }
78  messageHeader[0] = (uint8_t)status;
79  encodeInt(length, messageHeader, STATUS_BYTES);
80  transport_->write(messageHeader, HEADER_LENGTH);
81  transport_->write(payload, length);
82  if (flush) transport_->flush();
83  }
84 
87  uint32_t resLength;
88 
89  // Only client should open the underlying transport.
90  if (isClient_ && !transport_->isOpen()) {
91  transport_->open();
92  }
93 
94  // initiate SASL message
96 
97  // SASL connection handshake
98  while (!sasl_->isComplete()) {
99  uint8_t* message = receiveSaslMessage(&status, &resLength);
100  if (status == TSASL_COMPLETE) {
101  if (isClient_) break; // handshake complete
102  } else if (status != TSASL_OK) {
103  stringstream ss;
104  ss << "Expected COMPLETE or OK, got " << status;
105  throw TTransportException(ss.str());
106  }
107  uint32_t challengeLength;
108  uint8_t* challenge = sasl_->evaluateChallengeOrResponse(
109  message, resLength, &challengeLength);
110  sendSaslMessage(sasl_->isComplete() ? TSASL_COMPLETE : TSASL_OK,
111  challenge, challengeLength);
112  }
113 
114  // If the server isn't complete yet, we need to wait for its response.
115  // This will occur with ANONYMOUS auth, for example, where we send an
116  // initial response and are immediately complete.
117  if (isClient_ && (status == TSASL_INVALID || status == TSASL_OK)) {
118  receiveSaslMessage(&status, &resLength);
119  if (status != TSASL_COMPLETE) {
120  stringstream ss;
121  ss << "Expected COMPLETE or OK, got " << status;
122  throw TTransportException(ss.str());
123  }
124  }
125 
126  // TODO : need to set the shouldWrap_ based on QOP
127  /*
128  String qop = (String) sasl.getNegotiatedProperty(Sasl.QOP);
129  if (qop != null && !qop.equalsIgnoreCase("auth"))
130  shouldWrap_ = true;
131  */
132  }
133 
135  transport_->close();
136  }
137 
139  uint8_t lenBuf[PAYLOAD_LENGTH_BYTES];
140 
141  transport_->readAll(lenBuf, PAYLOAD_LENGTH_BYTES);
142  int32_t len = decodeInt(lenBuf, 0);
143  if (len < 0) {
144  throw TTransportException("Frame size has negative value");
145  }
146  return static_cast<uint32_t>(len);
147  }
148 
150  // readEnd() returns the number of bytes already read, i.e. the number of 'junk' bytes
151  // taking up space at the front of the memory buffer.
152  uint32_t read_end = memBuf_->readEnd();
153 
154  // If the size of the junk space at the beginning of the buffer is too large, and
155  // there's no data left in the buffer to read (number of bytes read == number of bytes
156  // written), then shrink the buffer back to the default. We don't want to do this on
157  // every read that exhausts the buffer, since the layer above often reads in small
158  // chunks, which is why we only resize if there's too much junk. The write and read
159  // pointers will eventually catch up after every RPC, so we will always take this path
160  // eventually once the buffer becomes sufficiently full.
161  //
162  // readEnd() may reset the write / read pointers (but only once if there's no
163  // intervening read or write between calls), so needs to be called a second time to
164  // get their current position.
165  if (read_end > DEFAULT_MEM_BUF_SIZE && memBuf_->writeEnd() == memBuf_->readEnd()) {
166  memBuf_->resetBuffer(DEFAULT_MEM_BUF_SIZE);
167  }
168  }
169 
170  uint32_t TSaslTransport::read(uint8_t* buf, uint32_t len) {
171  uint32_t read_bytes = memBuf_->read(buf, len);
172 
173  if (read_bytes > 0) {
174  shrinkBuffer();
175  return read_bytes;
176  }
177 
178  // if there's not enough data in cache, read from underlying transport
179  uint32_t dataLength = readLength();
180 
181  // Fast path
182  if (len == dataLength && !shouldWrap_) {
183  transport_->readAll(buf, len);
184  return len;
185  }
186 
187  uint8_t* tmpBuf = new uint8_t[dataLength];
188  transport_->readAll(tmpBuf, dataLength);
189  if (shouldWrap_) {
190  tmpBuf = sasl_->unwrap(tmpBuf, 0, dataLength, &dataLength);
191  }
192 
193  // We will consume all the data, no need to put it in the memory buffer.
194  if (len == dataLength) {
195  memcpy(buf, tmpBuf, len);
196  delete[] tmpBuf;
197  return len;
198  }
199 
200  memBuf_->write(tmpBuf, dataLength);
201  memBuf_->flush();
202  delete[] tmpBuf;
203 
204  uint32_t ret = memBuf_->read(buf, len);
205  shrinkBuffer();
206  return ret;
207  }
208 
209  void TSaslTransport::writeLength(uint32_t length) {
210  uint8_t lenBuf[PAYLOAD_LENGTH_BYTES];
211 
212  encodeInt(length, lenBuf, 0);
213  transport_->write(lenBuf, PAYLOAD_LENGTH_BYTES);
214  }
215 
216  void TSaslTransport::write(const uint8_t* buf, uint32_t len) {
217  const uint8_t* newBuf;
218 
219  if (shouldWrap_) {
220  newBuf = sasl_->wrap((uint8_t*)buf, 0, len, &len);
221  } else {
222  newBuf = buf;
223  }
224  writeLength(len);
225  transport_->write(newBuf, len);
226  }
227 
229  transport_->flush();
230  }
231 
233  uint32_t* length) {
234  uint8_t messageHeader[HEADER_LENGTH];
235 
236  // read header
237  transport_->readAll(messageHeader, HEADER_LENGTH);
238 
239  // get payload status
240  *status = (NegotiationStatus)messageHeader[0];
241  if ((*status < TSASL_START) || (*status > TSASL_COMPLETE)) {
242  throw TTransportException("invalid sasl status");
243  } else if (*status == TSASL_BAD || *status == TSASL_ERROR) {
244  throw TTransportException("sasl Peer indicated failure: ");
245  }
246 
247  // get the length
248  *length = decodeInt(messageHeader, STATUS_BYTES);
249 
250  // get payload
251  protoBuf_.reset(new uint8_t[*length]);
252  transport_->readAll(protoBuf_.get(), *length);
253 
254  return protoBuf_.get();
255  }
256 }
257 }
258 }
259 
260 #endif
boost::scoped_array< uint8_t > protoBuf_
Buffer to hold protocol info.
void write(const uint8_t *buf, uint32_t len)
void sendSaslMessage(const NegotiationStatus status, const uint8_t *payload, const uint32_t length, bool flush=true)
const int32_t DEFAULT_MEM_BUF_SIZE
bool isClient_
True if this is a client.
uint32_t decodeInt(uint8_t *buf, uint32_t offset)
boost::shared_ptr< sasl::TSasl > sasl_
bool shouldWrap_
If true, data is passed through SASL wrap/unwrap (e.g. for encryption).
TMemoryBuffer * memBuf_
Buffer for reading and writing.
boost::shared_ptr< TTransport > transport_
Underlying transport.
static const int HEADER_LENGTH
static const int PAYLOAD_LENGTH_BYTES
uint32_t read(uint8_t *buf, uint32_t len)
TSaslTransport(boost::shared_ptr< TTransport > transport)
void encodeInt(uint32_t x, uint8_t *buf, uint32_t offset)
uint8_t * receiveSaslMessage(NegotiationStatus *status, uint32_t *length)
static const int STATUS_BYTES