Working HTTP tunnel

This commit is contained in:
2025-07-05 06:45:39 +00:00
parent 9fdaf0fc59
commit eea345e93e
24 changed files with 1413 additions and 0 deletions

View File

@ -0,0 +1,435 @@
package dev.thinhha.tunnel_client.service;
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.thinhha.tunnel_client.config.TunnelConfig;
import dev.thinhha.tunnel_client.dto.TunnelRequestDto;
import dev.thinhha.tunnel_client.dto.TunnelResponseDto;
import dev.thinhha.tunnel_client.types.TunnelRequestType;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.*;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.socket.*;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.handler.BinaryWebSocketHandler;
import java.io.IOException;
import java.net.URI;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
@Service
@Slf4j
public class TunnelClient extends BinaryWebSocketHandler {
private final TunnelConfig tunnelConfig;
private final RouteResolver routeResolver;
private final HeaderManipulationService headerManipulationService;
private final ObjectMapper objectMapper;
private final RestTemplate restTemplate;
private WebSocketSession session;
private final CountDownLatch connectionLatch = new CountDownLatch(1);
// Track WebSocket connections: wsConnectionId -> target WebSocket session
private final Map<String, WebSocketSession> webSocketConnections = new ConcurrentHashMap<>();
public TunnelClient(TunnelConfig tunnelConfig, RouteResolver routeResolver, HeaderManipulationService headerManipulationService, ObjectMapper objectMapper, RestTemplate restTemplate) {
this.tunnelConfig = tunnelConfig;
this.routeResolver = routeResolver;
this.headerManipulationService = headerManipulationService;
this.objectMapper = objectMapper;
this.restTemplate = restTemplate;
}
public void connect() {
try {
StandardWebSocketClient client = new StandardWebSocketClient();
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
if (!tunnelConfig.getClient().getToken().isEmpty()) {
headers.add("Authorization", "Bearer " + tunnelConfig.getClient().getToken());
}
URI serverUri = URI.create(tunnelConfig.getServer().getUrl() + "/client");
log.info("Connecting to tunnel server at: {}", serverUri);
client.execute(this, headers, serverUri);
connectionLatch.await();
log.info("Connected to tunnel server as client: {}", tunnelConfig.getClient().getName());
} catch (Exception e) {
log.error("Failed to connect to tunnel server", e);
}
}
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
this.session = session;
log.info("WebSocket connection established with session: {}", session.getId());
connectionLatch.countDown();
}
@Override
public void handleBinaryMessage(WebSocketSession session, BinaryMessage message) throws Exception {
try {
byte[] payload = message.getPayload().array();
TunnelRequestDto request = objectMapper.readValue(payload, TunnelRequestDto.class);
log.info("Received tunnel request: {} {} {} {}", request.getRequestId(), request.getType(), request.getMethod(), request.getPath());
TunnelResponseDto response = switch (request.getType()) {
case HTTP -> handleHttpTunnelRequest(request);
case WS_CONNECT -> handleWebSocketConnect(request);
case WS_MESSAGE -> handleWebSocketMessage(request);
case WS_CLOSE -> handleWebSocketClose(request);
};
byte[] responseBytes = objectMapper.writeValueAsBytes(response);
session.sendMessage(new BinaryMessage(responseBytes));
} catch (Exception e) {
log.error("Error processing tunnel request", e);
}
}
private TunnelResponseDto handleHttpTunnelRequest(TunnelRequestDto request) {
try {
String targetUrl = routeResolver.resolveTargetUrl(request.getPath());
String fullUrl = targetUrl + request.getPath();
log.info("Forwarding request {} {} to: {}", request.getMethod(), request.getPath(), fullUrl);
HttpHeaders headers = new HttpHeaders();
if (request.getHeaders() != null) {
request.getHeaders().forEach((key, value) -> {
// Handle Content-Type specifically to ensure proper parsing
if ("Content-Type".equalsIgnoreCase(key)) {
headers.set(key, value);
log.debug("Setting Content-Type: {}", value);
} else if ("Content-Length".equalsIgnoreCase(key)) {
// Skip Content-Length as RestTemplate will set it automatically
log.debug("Skipping Content-Length header (will be set automatically)");
} else {
headers.add(key, value);
}
});
}
// Create HTTP entity with proper body handling
final HttpEntity<?> httpEntity = createHttpEntity(request, headers);
// Use byte[] to handle all response types properly
ResponseEntity<byte[]> response = restTemplate.exchange(
fullUrl,
HttpMethod.valueOf(request.getMethod().name()),
httpEntity,
byte[].class
);
// Extract and process response headers
Map<String, String> responseHeaders = extractAndProcessHeaders(response, request.getPath());
int statusCode = response.getStatusCode().value();
log.info("Target service responded: {} {} -> {} {}",
request.getMethod(), request.getPath(), statusCode,
getStatusCodeDescription(statusCode));
TunnelResponseDto tunnelResponse = new TunnelResponseDto();
tunnelResponse.setRequestId(request.getRequestId());
tunnelResponse.setType(TunnelRequestType.HTTP);
tunnelResponse.setStatusCode(statusCode);
tunnelResponse.setHeaders(responseHeaders);
tunnelResponse.setBody(response.getBody());
return tunnelResponse;
} catch (Exception e) {
log.error("Error forwarding request to target service", e);
Map<String, String> errorHeaders = new HashMap<>();
errorHeaders.put("Content-Type", "text/plain");
TunnelResponseDto errorResponse = new TunnelResponseDto();
errorResponse.setRequestId(request.getRequestId());
errorResponse.setType(TunnelRequestType.HTTP);
errorResponse.setStatusCode(500);
errorResponse.setHeaders(errorHeaders);
errorResponse.setBody(("Internal Server Error: " + e.getMessage()).getBytes());
return errorResponse;
}
}
@Override
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
log.error("Transport error: {}", exception.getMessage());
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
log.info("Connection closed with status: {}", closeStatus);
this.session = null;
}
@Override
public boolean supportsPartialMessages() {
return false;
}
private String getStatusCodeDescription(int statusCode) {
return switch (statusCode / 100) {
case 2 -> "Success";
case 3 -> "Redirection";
case 4 -> "Client Error";
case 5 -> "Server Error";
default -> "Unknown";
};
}
private HttpEntity<?> createHttpEntity(TunnelRequestDto request, HttpHeaders headers) {
if (request.getBody() != null && request.getBody().length > 0) {
String contentType = headers.getFirst("Content-Type");
if (contentType != null && contentType.startsWith("application/json")) {
// For JSON, convert bytes to string for proper handling
String jsonBody = new String(request.getBody());
return new HttpEntity<>(jsonBody, headers);
} else if (contentType != null && contentType.startsWith("application/x-www-form-urlencoded")) {
// For form data, convert bytes to string
String formBody = new String(request.getBody());
return new HttpEntity<>(formBody, headers);
} else {
// For binary data or other content types, use byte array
return new HttpEntity<>(request.getBody(), headers);
}
} else {
return new HttpEntity<>(headers);
}
}
private Map<String, String> extractAndProcessHeaders(ResponseEntity<byte[]> response, String path) {
Map<String, String> responseHeaders = new HashMap<>();
response.getHeaders().forEach((key, values) -> {
if (!values.isEmpty()) {
responseHeaders.put(key, values.get(0));
}
});
// Apply header manipulation rules
return headerManipulationService.processResponseHeaders(path, responseHeaders);
}
private TunnelResponseDto handleWebSocketConnect(TunnelRequestDto request) {
try {
String targetUrl = routeResolver.resolveTargetUrl(request.getPath());
String wsUrl = targetUrl.replace("http://", "ws://").replace("https://", "wss://") + request.getPath();
log.info("Establishing WebSocket connection to: {}", wsUrl);
StandardWebSocketClient client = new StandardWebSocketClient();
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
if (request.getHeaders() != null) {
request.getHeaders().forEach(headers::add);
}
// Create handler for target WebSocket
WebSocketHandler targetHandler = new WebSocketHandler() {
@Override
public void afterConnectionEstablished(WebSocketSession targetSession) throws Exception {
webSocketConnections.put(request.getWsConnectionId(), targetSession);
log.info("WebSocket connection established: {}", request.getWsConnectionId());
// Send success response back to tunnel server
TunnelResponseDto response = new TunnelResponseDto();
response.setRequestId(request.getRequestId());
response.setType(TunnelRequestType.WS_CONNECT);
response.setStatusCode(101); // WebSocket upgrade
response.setHeaders(new HashMap<>());
response.setBody(new byte[0]);
response.setWsConnectionId(request.getWsConnectionId());
response.setWsConnectionEstablished(true);
try {
byte[] responseBytes = objectMapper.writeValueAsBytes(response);
session.sendMessage(new BinaryMessage(responseBytes));
} catch (Exception e) {
log.error("Error sending WebSocket connect response", e);
}
}
@Override
public void handleMessage(WebSocketSession targetSession, WebSocketMessage<?> message) throws Exception {
// Forward message from target to tunnel server
String messageType;
byte[] messageBody;
if (message instanceof TextMessage textMsg) {
messageType = "TEXT";
messageBody = textMsg.getPayload().getBytes();
} else if (message instanceof BinaryMessage binaryMsg) {
messageType = "BINARY";
messageBody = binaryMsg.getPayload().array();
} else {
return; // Unknown message type
}
TunnelResponseDto response = new TunnelResponseDto(
java.util.UUID.randomUUID().toString(),
TunnelRequestType.WS_MESSAGE,
200,
new HashMap<>(),
messageBody,
request.getWsConnectionId(),
messageType,
false
);
try {
byte[] responseBytes = objectMapper.writeValueAsBytes(response);
session.sendMessage(new BinaryMessage(responseBytes));
} catch (Exception e) {
log.error("Error forwarding WebSocket message", e);
}
}
@Override
public void afterConnectionClosed(WebSocketSession targetSession, CloseStatus closeStatus) throws Exception {
webSocketConnections.remove(request.getWsConnectionId());
log.info("WebSocket connection closed: {}", request.getWsConnectionId());
// Notify tunnel server of connection close
TunnelResponseDto response = new TunnelResponseDto(
java.util.UUID.randomUUID().toString(),
TunnelRequestType.WS_CLOSE,
closeStatus.getCode(),
new HashMap<>(),
new byte[0],
request.getWsConnectionId(),
null,
false
);
try {
byte[] responseBytes = objectMapper.writeValueAsBytes(response);
session.sendMessage(new BinaryMessage(responseBytes));
} catch (Exception e) {
log.error("Error sending WebSocket close notification", e);
}
}
@Override
public void handleTransportError(WebSocketSession targetSession, Throwable exception) throws Exception {
log.error("WebSocket transport error: {}", exception.getMessage());
}
@Override
public boolean supportsPartialMessages() {
return false;
}
};
client.execute(targetHandler, headers, URI.create(wsUrl));
// Return immediate response (actual connection established response sent in handler)
TunnelResponseDto response = new TunnelResponseDto();
response.setRequestId(request.getRequestId());
response.setType(TunnelRequestType.WS_CONNECT);
response.setStatusCode(102); // Processing
response.setHeaders(new HashMap<>());
response.setBody(new byte[0]);
response.setWsConnectionId(request.getWsConnectionId());
response.setWsConnectionEstablished(false);
return response;
} catch (Exception e) {
log.error("Error establishing WebSocket connection", e);
TunnelResponseDto response = new TunnelResponseDto();
response.setRequestId(request.getRequestId());
response.setType(TunnelRequestType.WS_CONNECT);
response.setStatusCode(500);
response.setHeaders(new HashMap<>());
response.setBody(("WebSocket connection failed: " + e.getMessage()).getBytes());
response.setWsConnectionId(request.getWsConnectionId());
response.setWsConnectionEstablished(false);
return response;
}
}
private TunnelResponseDto handleWebSocketMessage(TunnelRequestDto request) {
try {
WebSocketSession targetSession = webSocketConnections.get(request.getWsConnectionId());
if (targetSession == null || !targetSession.isOpen()) {
log.warn("WebSocket connection not found or closed: {}", request.getWsConnectionId());
TunnelResponseDto response = new TunnelResponseDto();
response.setRequestId(request.getRequestId());
response.setType(TunnelRequestType.WS_MESSAGE);
response.setStatusCode(404);
response.setHeaders(new HashMap<>());
response.setBody("WebSocket connection not found".getBytes());
response.setWsConnectionId(request.getWsConnectionId());
return response;
}
// Forward message to target
if ("TEXT".equals(request.getWsMessageType())) {
String textPayload = new String(request.getBody());
targetSession.sendMessage(new TextMessage(textPayload));
} else {
targetSession.sendMessage(new BinaryMessage(request.getBody()));
}
TunnelResponseDto response = new TunnelResponseDto();
response.setRequestId(request.getRequestId());
response.setType(TunnelRequestType.WS_MESSAGE);
response.setStatusCode(200);
response.setHeaders(new HashMap<>());
response.setBody(new byte[0]);
response.setWsConnectionId(request.getWsConnectionId());
return response;
} catch (Exception e) {
log.error("Error handling WebSocket message", e);
TunnelResponseDto response = new TunnelResponseDto();
response.setRequestId(request.getRequestId());
response.setType(TunnelRequestType.WS_MESSAGE);
response.setStatusCode(500);
response.setHeaders(new HashMap<>());
response.setBody(("WebSocket message failed: " + e.getMessage()).getBytes());
response.setWsConnectionId(request.getWsConnectionId());
return response;
}
}
private TunnelResponseDto handleWebSocketClose(TunnelRequestDto request) {
try {
WebSocketSession targetSession = webSocketConnections.remove(request.getWsConnectionId());
if (targetSession != null && targetSession.isOpen()) {
targetSession.close();
log.info("Closed WebSocket connection: {}", request.getWsConnectionId());
}
TunnelResponseDto response = new TunnelResponseDto();
response.setRequestId(request.getRequestId());
response.setType(TunnelRequestType.WS_CLOSE);
response.setStatusCode(200);
response.setHeaders(new HashMap<>());
response.setBody(new byte[0]);
response.setWsConnectionId(request.getWsConnectionId());
return response;
} catch (Exception e) {
log.error("Error closing WebSocket connection", e);
TunnelResponseDto response = new TunnelResponseDto();
response.setRequestId(request.getRequestId());
response.setType(TunnelRequestType.WS_CLOSE);
response.setStatusCode(500);
response.setHeaders(new HashMap<>());
response.setBody(new byte[0]);
response.setWsConnectionId(request.getWsConnectionId());
return response;
}
}
}