Writing an intercepting TCP proxy for HANA DB authentication
Why an intercepting proxy?
In some cases we cannot change the behavior of an HDB client. For instance, a client may ask only for the HANA instance number XX and derive the HDB port as 3XX15. Meanwhile, we may deploy HDB behind a port mapping, so the real HDB port can be anything, and a third-party client gives us no way to enter an absolute port number.
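The port convention can be written down as a one-line helper. This is a minimal sketch that assumes only the 3XX15 pattern described above; the function name hdb_sql_port is made up for illustration.

```python
def hdb_sql_port(instance_number: int) -> int:
    """Return the port a client derives from a two-digit instance number (3XX15)."""
    if not 0 <= instance_number <= 99:
        raise ValueError("instance number must be between 00 and 99")
    return 30015 + 100 * instance_number


print(hdb_sql_port(0))   # 30015
print(hdb_sql_port(12))  # 31215
```

A client that only accepts the instance number can therefore reach 100 possible ports and nothing else, which is why a remapped port needs a proxy in front of it.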
In other cases we don't want to connect to HDB with user/password authentication. A good alternative is authentication through SAML or Kerberos. The problem is the same as in the first case: we cannot change the client.
We can either rewrite the client, which may be a binary without source code, or use a proxy.
What is a TCP proxy?
Web applications normally speak HTTP, and writing a proxy for them is not hard because HTTP is such a common protocol. For a database interface, however, we need to understand its wire protocol, which is usually defined by the vendor.
Once we know the wire protocol, we capture the TCP packets, parse them according to the protocol, and forward them. Because it is hard to find a library for this kind of SQL interface protocol, we have to write everything ourselves at the TCP layer. That is why it is called a TCP proxy.
The TCP packets look like the following. The red lines are requests from the client and the blue lines are responses from the server. (Of course, the terms ‘request’ and ‘response’ are not strictly correct here.)
Implementing the transparent proxy
First of all, a transparent proxy is easy to implement.
In Python, we use “select” for I/O multiplexing to watch the sockets. When a client connects to the listening socket, a forwarding socket to the server is created and paired with the client socket.
import select
import socket
...
# server is the listening socket that the proxy is deployed on.
input_list.append(server)
while True:
    inputready, outputready, exceptready = select.select(input_list, [], [])
    for s in inputready:
        if s == server:
            on_accept()
            break
        ...

def on_accept():
    # The socket to the real server. connect() returns None,
    # so keep a reference to the socket itself.
    forward = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    forward.connect((host, port))
    # The client socket
    clientsock, clientaddr = server.accept()
    # Watch both sockets and pair them with each other.
    input_list.append(clientsock)
    input_list.append(forward)
    channel[clientsock] = forward
    channel[forward] = clientsock
Through this mapping we can reach the opposite socket from the current one, regardless of whether the data came from the server or the client. For example:
opposite_socket = channel[coming_socket]
Then we can simply use the Python socket API to read data from one socket and send it to the other.
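The read-and-forward step can be sketched as follows. The helper name relay_once is made up for this example, and in-process socket pairs stand in for the real client and server so the snippet runs without a network.

```python
import socket


def relay_once(incoming, channel, buffer_size=4096):
    """Read one chunk from `incoming` and pass it to its paired socket.

    Returns the number of bytes relayed; 0 means the peer closed the connection.
    """
    data = incoming.recv(buffer_size)
    if data:
        channel[incoming].sendall(data)
    return len(data)


# Two socket pairs: client <-> proxy, and proxy <-> server.
client, proxy_client_side = socket.socketpair()
proxy_server_side, server_end = socket.socketpair()
channel = {
    proxy_client_side: proxy_server_side,
    proxy_server_side: proxy_client_side,
}

client.sendall(b"hello")
relayed = relay_once(proxy_client_side, channel)
received = server_end.recv(4096)
print(relayed, received)

for sock in (client, proxy_client_side, proxy_server_side, server_end):
    sock.close()
```

Because the mapping is symmetric, the same function serves both directions: whichever socket “select” reports as readable is passed in as `incoming`.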
How to intercept?
Simply modify the data read from the incoming socket before sending it on.
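As a small illustration, the modification can be a function applied between the read and the send. Both forward_with_hook and mask_password are hypothetical names, and the payload is invented; it is not a HANA packet.

```python
def forward_with_hook(data, send, hook=None):
    """Apply `hook` to the bytes read from the incoming socket, then send the result."""
    if hook is not None:
        data = hook(data)
    send(data)
    return data


def mask_password(data: bytes) -> bytes:
    # Hypothetical rewrite: a same-length substitution, so any length
    # fields inside the packet remain valid.
    return data.replace(b"secret", b"******")


sent = []  # stands in for the opposite socket's send method
forward_with_hook(b"user=alice;pw=secret", sent.append, mask_password)
print(sent)
```

A same-length replacement is the easy case. As soon as the rewritten data is longer or shorter, every length field that covers it has to be patched as well.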
How to redirect dynamically in the HANA DB case?
How to override the authentication method?
In the code above, the forwarding socket is connected to a fixed (host, port) as soon as the client is accepted.
Here the host and port need to be determined dynamically from one of the packets. In the HANA DB case, the client first sends a handshake packet and receives the reply, then sends the first SCRAMSHA256 request packet. We can put the target host and port into this packet.
As a result, the procedure should look like this:
We hold back the first and second packets from the client and read information such as the new instance number from them. Then we re-send these two requests to the real server.
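The hold-and-replay idea can be sketched like this. DeferredUpstream and FakeSocket are names invented for the example, the host name is a placeholder, and a fake socket is used so the snippet runs without a server. With real sockets, the connect callable could be something like `lambda host, port: socket.create_connection((host, port))`.

```python
class DeferredUpstream:
    """Hold client packets until the target is known, then connect and replay them."""

    def __init__(self, connect):
        self._connect = connect  # callable (host, port) -> connected socket-like object
        self._pending = []
        self.upstream = None

    def hold(self, data):
        self._pending.append(bytes(data))

    def open_and_replay(self, host, port):
        self.upstream = self._connect(host, port)
        for packet in self._pending:
            self.upstream.sendall(packet)
        self._pending.clear()
        return self.upstream


class FakeSocket:
    """Stand-in for a real socket that records what was sent."""

    def __init__(self, address):
        self.address = address
        self.sent = []

    def sendall(self, data):
        self.sent.append(data)


deferred = DeferredUpstream(lambda host, port: FakeSocket((host, port)))
deferred.hold(b"handshake")
deferred.hold(b"first-auth-request")
upstream = deferred.open_and_replay("hana.example", 31215)
print(upstream.address, upstream.sent)
```

This sketch replays the packets back to back. A real proxy also has to answer the client's handshake itself and wait for the server's handshake reply before sending the authentication request.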
We can also pass more information, such as a user or session ID, to drive our own authentication. All we need to do is follow the wire protocol and the algorithm to reconstruct the packets passing through.
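One way to carry the extra information is the convention used in the full listing at the end of this post: the user name field holds a comma-separated string that starts with the marker IAUID. This layout is our own convention, not part of the HANA protocol, and the sample values below are made up. The helper name parse_routing_user is hypothetical.

```python
def parse_routing_user(field: bytes):
    """Split an extended user field of the form b"IAUID,<user>,<instance>,<session>"."""
    parts = field.split(b",")
    if len(parts) != 4 or parts[0] != b"IAUID":
        raise ValueError("not an extended user field")
    _, user, instance, session = parts
    return user.decode("ascii"), int(instance.decode("ascii")), session.decode("ascii")


user, instance, session = parse_routing_user(b"IAUID,SYSTEM,12,abc123")
print(user, instance, session)
```

The proxy then strips the field back down to the plain user name before forwarding, so the server never sees the routing data.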
Notes
In our case, we assume each packet is small enough that it will not be split into chunks. If it is split, we need to buffer all the chunks and reassemble the packet before it can be parsed.
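A minimal sketch of such buffering follows. It assumes that every message after the initial handshake starts with a 32-byte header whose bytes 12 to 15 hold the payload length as a little-endian unsigned 32-bit integer. That is our reading of the protocol reference and should be checked against it. The class name MessageAssembler is made up for this example, and the test message is synthetic.

```python
import struct

HEADER_SIZE = 32  # assumed fixed message header size


class MessageAssembler:
    """Collect TCP chunks and hand back complete messages."""

    def __init__(self):
        self._buffer = bytearray()

    def feed(self, chunk):
        self._buffer += chunk
        messages = []
        while len(self._buffer) >= HEADER_SIZE:
            # Assumption: bytes 12..15 hold the payload length (little-endian uint32).
            (payload_length,) = struct.unpack_from("<I", self._buffer, 12)
            total = HEADER_SIZE + payload_length
            if len(self._buffer) < total:
                break  # wait for more chunks
            messages.append(bytes(self._buffer[:total]))
            del self._buffer[:total]
        return messages


# Build a synthetic 32-byte header plus a 5-byte payload and deliver it in two chunks.
header = struct.pack("<qiIIhb9s", 0, 1, 5, 5, 1, 0, b"")
message = header + b"hello"
assembler = MessageAssembler()
first = assembler.feed(message[:20])
second = assembler.feed(message[20:])
print(first, second == [message])
```

The handshake packets at the very start of a connection do not carry this header, so chunks should only be fed to the assembler once the handshake is over.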
SCRAMSHA256
Here is a brief introduction to the SCRAMSHA256 authentication algorithm that HANA uses.
1. The client sends a random number (cnonce).
2. The server sends a random number (snonce) and a salt (salt).
3. The client computes a proof from the password, cnonce, snonce and salt, and sends it to the server.
4. The server confirms.
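Step 3 can be sketched in a few lines. The computation below is modeled on how we understand the open-source pyhdb client to derive the proof for a single salt; treat it as an assumption and verify it against the protocol reference before relying on it. The nonce and salt sizes (64, 48 and 16 bytes) are the ones used in the listing at the end of this post, and the password is a placeholder.

```python
import hashlib
import hmac
import os


def client_proof(password: bytes, salt: bytes, server_nonce: bytes, client_nonce: bytes) -> bytes:
    """Sketch of the SCRAMSHA256 client proof for one salt (assumed, modeled on pyhdb)."""
    salted = hmac.new(password, salt, hashlib.sha256).digest()
    key = hashlib.sha256(salted).digest()
    key_hash = hashlib.sha256(key).digest()
    message = salt + server_nonce + client_nonce
    signature = hmac.new(key_hash, message, hashlib.sha256).digest()
    scrambled = bytes(a ^ b for a, b in zip(signature, key))
    # Framing: 2-byte count (one proof), 1-byte length (32), then the 32-byte proof.
    return b"\x00\x01" + bytes([len(scrambled)]) + scrambled


proof = client_proof(b"secret", os.urandom(16), os.urandom(48), os.urandom(64))
print(len(proof))  # 35
```

The password itself never crosses the wire. A proxy that knows the real password can therefore recompute the proof with the nonces and salt it observed and substitute it for whatever the client sent.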
Wire-protocol
SCRAMSHA256 Authentication – SAP HANA SQL Command Network Protocol Reference – SAP Library
Proxy in Python
#!/usr/bin/python
import socket
import select
import time
import sys
from pyhdb.auth import AuthManager
import logging
# Changing buffer_size and delay can improve speed and bandwidth,
# but if the buffer gets too large or the delay too small, things can break.
buffer_size = 4096
delay = 0.0001
forward_to = ('10.58.5.15', 30015)
log = logging
log.basicConfig(level = log.NOTSET)
#forward_to = ('', 3306)
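def matchBytes(source, substring):
    """Return the index of the first occurrence of substring in source, or -1."""
    sourceLength = len(source)
    substringLength = len(substring)
    # Stop early enough that source[i + k] never runs past the end.
    for i in range(0, sourceLength - substringLength + 1):
        flag = True
        for k in range(0, substringLength):
            if source[i + k] != substring[k]:
                flag = False
                break
        if flag:
            return i
    return -1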
class Forward:
    def __init__(self):
        self.forward = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

    def start(self, host, port):
        try:
            self.forward.connect((host, port))
            return self.forward
        except Exception as e:
            log.error(e)
            return False
class TheServer:
    input_list = []
    channel = {}
    __connection = {}
    i = 0

    def __init__(self, host, port):
        self.server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.server.bind((host, port))
        self.server.listen(200)
        self.countFrom = 0
        self.countTo = 0
    def main_loop(self):
        self.input_list.append(self.server)
        while True:
            time.sleep(delay)
            ss = select.select
            inputready, outputready, exceptready = ss(self.input_list, [], [])
            for self.s in inputready:
                if self.s == self.server:
                    self.on_accept()
                    break
                self.data = self.s.recv(buffer_size)
                if len(self.data) == 0:
                    self.on_close()
                    break
                else:
                    self.on_recv()
    def on_accept(self):
        # The connection to the real server is deferred until the first
        # authentication request arrives (see onInitialRequest), so only the
        # client socket is registered here.
        #forward = Forward().start(forward_to[0], forward_to[1])
        self.forward__ = None
        clientsock, clientaddr = self.server.accept()
        self.clientsock__ = clientsock
        self.clientaddr__ = clientaddr
        # Per-connection packet counters. They live on the server object, so
        # this proxy handles one client connection at a time.
        self.countTo = 0
        self.count_in = 0
        self.count_out = 0
        self.input_list.append(clientsock)
        log.info("%s accepted, upstream connection deferred", clientaddr)
    def on_close(self):
        log.info("%s %s", self.s.getpeername(), "has disconnected")
        self.count_in = 0
        self.count_out = 0
        # remove objects from input_list
        self.input_list.remove(self.s)
        out = self.channel.get(self.s)
        if out is None:
            # The client left before the upstream connection was opened.
            self.s.close()
            return
        self.input_list.remove(out)
        # close both ends: out is the opposite socket, channel[out] is self.s
        out.close()
        self.channel[out].close()
        # delete both objects from channel dict
        del self.channel[out]
        del self.channel[self.s]
    def on_recv(self):
        data = self.data
        # here we can parse and/or modify the data before sending it forward
        #print(data)
        # countTo numbers the packets of the connection in arrival order,
        # regardless of direction; count_out / count_in split them by source.
        self.countTo += 1
        log.debug(self.countTo)
        if self.forward__ and self.forward__ == self.s:
            self.count_out += 1
        else:
            self.count_in += 1
        if self.countTo == 1:
            # Client handshake: answer it ourselves, the server is not connected yet.
            self.clientsock__.send(b"\x01\x00\x00\x04\x01\x00\x00\x00")
            log.debug("echo finished")
            return
        if self.countTo == 2:
            # First authentication request: pick the target, connect, and hold
            # the request back until the server has answered our handshake.
            data = self.onInitialRequest(data)
            self.tempdata = data
            self.forward__.send(b"\xff\xff\xff\xff\x04\x14\x00\x04\x01\x00\x00\x01\x01\x01")
            return
        if self.countTo == 3:
            # Server handshake reply: drop it and send the held-back request.
            data = self.tempdata
            self.forward__.send(data)
            return
        if self.countTo == 4:
            data = self.onInitialResponse(data)
        if self.countTo == 5:
            data = self.onFinalRequest(data)
        log.debug("Send data")
        self.channel[self.s].send(data)
    def onInitialRequest(self, data):
        data, username, instanceNumber, sessionId = self.replaceUserName(data, False)
        # The parsed instance number is overridden here, so the target port is always 30015.
        instanceNumber = 0
        forward = Forward().start(forward_to[0], 30015 + instanceNumber * 100)
        clientsock, clientaddr = (self.clientsock__, self.clientaddr__)
        if forward:
            self.file_in = "pkgs/data_forward_" + str(TheServer.i)
            self.file_out = "pkgs/data_receive_" + str(TheServer.i)
            TheServer.i += 1
            print(clientaddr, "has connected")
            #self.input_list.append(clientsock)
            self.input_list.append(forward)
            self.forward__ = forward
            self.channel[clientsock] = forward
            self.channel[forward] = clientsock
        else:
            print("Can't establish connection with remote server.")
            print("Closing connection with client side", clientaddr)
        # "SCRAMSHA256" followed by the length byte 0x40 (64) of the client nonce
        method = b"\x53\x43\x52\x41\x4D\x53\x48\x41\x32\x35\x36\x40"
        index = matchBytes(data, method) + len(method)
        self.cnonce = data[index:index+64]
        print(''.join('{:02x}'.format(x) for x in self.cnonce))
        return data
    def onInitialResponse(self, data):
        # "SCRAMSHA256", then the bytes 0x44 0x02 0x00 0x10 that precede the 16-byte salt
        method = b"\x53\x43\x52\x41\x4D\x53\x48\x41\x32\x35\x36\x44\x02\x00\x10"
        index = matchBytes(data, method) + len(method)
        self.salt = data[index:index+16]
        print(''.join('{:02x}'.format(x) for x in self.salt))
        # Skip the one byte between the salt and the 48-byte server nonce.
        self.snonce = data[index+16+1:index+16+1+48]
        print(''.join('{:02x}'.format(x) for x in self.snonce))
        return data
    def onFinalRequest(self, data):
        data = self.replaceUserName(data, True)[0]
        manager = AuthManager(None, "SYSTEM", "manager")
        manager.client_key = self.cnonce
        client_proof = manager.calculate_client_proof(
            [self.salt], self.snonce
        )
        print(''.join('{:02x}'.format(x) for x in client_proof))
        # "SCRAMSHA256" followed by the length byte 0x23 (35) of the client proof
        method = b"\x53\x43\x52\x41\x4D\x53\x48\x41\x32\x35\x36\x23"
        index = matchBytes(data, method) + len(method)
        values = bytearray(data)
        # Overwrite the client's 35-byte proof with the one computed by the proxy.
        for i in range(0, 35):
            values[index+i] = client_proof[i]
        return values
    def replaceUserName(self, data, append):
        data = bytearray(data)
        # "IAUID": marker that starts the extended user field
        iaUid = b"\x49\x41\x55\x49\x44"
        # 0x0b (length 11) followed by "SCRAMSHA256"
        method = b"\x0b\x53\x43\x52\x41\x4D\x53\x48\x41\x32\x35\x36"
        indexFrom = data.find(iaUid)
        indexTo = data.find(method)
        # values is "IAUID,<username>,<instanceNumber>,<sessionId>"
        values = bytearray(data[indexFrom:indexTo])
        print(values)
        fields = values.split(b",")
        username = fields[1]
        instanceNumber = fields[2]
        sessionId = fields[3]
        difflen = len(values) - len(username)
        m = difflen % 8 if append else 0
        # Patch the length byte of the user field and a length byte 11 bytes before it.
        data[indexFrom - 1] = len(username)
        print(data[indexFrom - 11])
        print(difflen)
        print(data[indexFrom - 11] - difflen)
        data[indexFrom - 11] -= difflen
        #data[indexFrom - 11 -32] -= difflen
        #data[indexFrom - 11 -32 - 20] -= difflen
        for x in range(0, m):
            data.insert(indexTo + 48, 0)
        print(''.join('{:02x}'.format(x) for x in data))
        f = data.replace(values, username, 1)
        # Pad with zero bytes so the total packet size stays unchanged.
        for x in range(0, difflen - m):
            f.append(0)
        print(f)
        return (f, username.decode("ascii"), int(instanceNumber.decode("ascii")), sessionId.decode("ascii"))
if __name__ == '__main__':
    server = TheServer('', 30015)
    try:
        server.main_loop()
    except KeyboardInterrupt:
        print("Ctrl C - Stopping server")
        sys.exit(1)