Working HTTP tunnel
This commit is contained in:
@ -0,0 +1,435 @@
|
||||
package dev.thinhha.tunnel_client.service;
|
||||
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import dev.thinhha.tunnel_client.config.TunnelConfig;
|
||||
import dev.thinhha.tunnel_client.dto.TunnelRequestDto;
|
||||
import dev.thinhha.tunnel_client.dto.TunnelResponseDto;
|
||||
import dev.thinhha.tunnel_client.types.TunnelRequestType;
|
||||
import lombok.extern.slf4j.Slf4j;
|
||||
import org.springframework.http.*;
|
||||
import org.springframework.stereotype.Service;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
import org.springframework.web.socket.*;
|
||||
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
|
||||
import org.springframework.web.socket.handler.BinaryWebSocketHandler;
|
||||
|
||||
import java.io.IOException;
|
||||
import java.net.URI;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
@Service
|
||||
@Slf4j
|
||||
public class TunnelClient extends BinaryWebSocketHandler {
|
||||
|
||||
private final TunnelConfig tunnelConfig;
|
||||
private final RouteResolver routeResolver;
|
||||
private final HeaderManipulationService headerManipulationService;
|
||||
private final ObjectMapper objectMapper;
|
||||
private final RestTemplate restTemplate;
|
||||
private WebSocketSession session;
|
||||
private final CountDownLatch connectionLatch = new CountDownLatch(1);
|
||||
|
||||
// Track WebSocket connections: wsConnectionId -> target WebSocket session
|
||||
private final Map<String, WebSocketSession> webSocketConnections = new ConcurrentHashMap<>();
|
||||
|
||||
public TunnelClient(TunnelConfig tunnelConfig, RouteResolver routeResolver, HeaderManipulationService headerManipulationService, ObjectMapper objectMapper, RestTemplate restTemplate) {
|
||||
this.tunnelConfig = tunnelConfig;
|
||||
this.routeResolver = routeResolver;
|
||||
this.headerManipulationService = headerManipulationService;
|
||||
this.objectMapper = objectMapper;
|
||||
this.restTemplate = restTemplate;
|
||||
}
|
||||
|
||||
public void connect() {
|
||||
try {
|
||||
StandardWebSocketClient client = new StandardWebSocketClient();
|
||||
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
|
||||
|
||||
if (!tunnelConfig.getClient().getToken().isEmpty()) {
|
||||
headers.add("Authorization", "Bearer " + tunnelConfig.getClient().getToken());
|
||||
}
|
||||
|
||||
URI serverUri = URI.create(tunnelConfig.getServer().getUrl() + "/client");
|
||||
log.info("Connecting to tunnel server at: {}", serverUri);
|
||||
|
||||
client.execute(this, headers, serverUri);
|
||||
|
||||
connectionLatch.await();
|
||||
log.info("Connected to tunnel server as client: {}", tunnelConfig.getClient().getName());
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("Failed to connect to tunnel server", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
|
||||
this.session = session;
|
||||
log.info("WebSocket connection established with session: {}", session.getId());
|
||||
connectionLatch.countDown();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleBinaryMessage(WebSocketSession session, BinaryMessage message) throws Exception {
|
||||
try {
|
||||
byte[] payload = message.getPayload().array();
|
||||
TunnelRequestDto request = objectMapper.readValue(payload, TunnelRequestDto.class);
|
||||
|
||||
log.info("Received tunnel request: {} {} {} {}", request.getRequestId(), request.getType(), request.getMethod(), request.getPath());
|
||||
|
||||
TunnelResponseDto response = switch (request.getType()) {
|
||||
case HTTP -> handleHttpTunnelRequest(request);
|
||||
case WS_CONNECT -> handleWebSocketConnect(request);
|
||||
case WS_MESSAGE -> handleWebSocketMessage(request);
|
||||
case WS_CLOSE -> handleWebSocketClose(request);
|
||||
};
|
||||
|
||||
byte[] responseBytes = objectMapper.writeValueAsBytes(response);
|
||||
session.sendMessage(new BinaryMessage(responseBytes));
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("Error processing tunnel request", e);
|
||||
}
|
||||
}
|
||||
|
||||
private TunnelResponseDto handleHttpTunnelRequest(TunnelRequestDto request) {
|
||||
try {
|
||||
String targetUrl = routeResolver.resolveTargetUrl(request.getPath());
|
||||
String fullUrl = targetUrl + request.getPath();
|
||||
|
||||
log.info("Forwarding request {} {} to: {}", request.getMethod(), request.getPath(), fullUrl);
|
||||
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
if (request.getHeaders() != null) {
|
||||
request.getHeaders().forEach((key, value) -> {
|
||||
// Handle Content-Type specifically to ensure proper parsing
|
||||
if ("Content-Type".equalsIgnoreCase(key)) {
|
||||
headers.set(key, value);
|
||||
log.debug("Setting Content-Type: {}", value);
|
||||
} else if ("Content-Length".equalsIgnoreCase(key)) {
|
||||
// Skip Content-Length as RestTemplate will set it automatically
|
||||
log.debug("Skipping Content-Length header (will be set automatically)");
|
||||
} else {
|
||||
headers.add(key, value);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Create HTTP entity with proper body handling
|
||||
final HttpEntity<?> httpEntity = createHttpEntity(request, headers);
|
||||
|
||||
// Use byte[] to handle all response types properly
|
||||
ResponseEntity<byte[]> response = restTemplate.exchange(
|
||||
fullUrl,
|
||||
HttpMethod.valueOf(request.getMethod().name()),
|
||||
httpEntity,
|
||||
byte[].class
|
||||
);
|
||||
|
||||
// Extract and process response headers
|
||||
Map<String, String> responseHeaders = extractAndProcessHeaders(response, request.getPath());
|
||||
|
||||
int statusCode = response.getStatusCode().value();
|
||||
log.info("Target service responded: {} {} -> {} {}",
|
||||
request.getMethod(), request.getPath(), statusCode,
|
||||
getStatusCodeDescription(statusCode));
|
||||
|
||||
TunnelResponseDto tunnelResponse = new TunnelResponseDto();
|
||||
tunnelResponse.setRequestId(request.getRequestId());
|
||||
tunnelResponse.setType(TunnelRequestType.HTTP);
|
||||
tunnelResponse.setStatusCode(statusCode);
|
||||
tunnelResponse.setHeaders(responseHeaders);
|
||||
tunnelResponse.setBody(response.getBody());
|
||||
|
||||
return tunnelResponse;
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("Error forwarding request to target service", e);
|
||||
|
||||
Map<String, String> errorHeaders = new HashMap<>();
|
||||
errorHeaders.put("Content-Type", "text/plain");
|
||||
|
||||
TunnelResponseDto errorResponse = new TunnelResponseDto();
|
||||
errorResponse.setRequestId(request.getRequestId());
|
||||
errorResponse.setType(TunnelRequestType.HTTP);
|
||||
errorResponse.setStatusCode(500);
|
||||
errorResponse.setHeaders(errorHeaders);
|
||||
errorResponse.setBody(("Internal Server Error: " + e.getMessage()).getBytes());
|
||||
|
||||
return errorResponse;
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
|
||||
log.error("Transport error: {}", exception.getMessage());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
|
||||
log.info("Connection closed with status: {}", closeStatus);
|
||||
this.session = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean supportsPartialMessages() {
|
||||
return false;
|
||||
}
|
||||
|
||||
private String getStatusCodeDescription(int statusCode) {
|
||||
return switch (statusCode / 100) {
|
||||
case 2 -> "Success";
|
||||
case 3 -> "Redirection";
|
||||
case 4 -> "Client Error";
|
||||
case 5 -> "Server Error";
|
||||
default -> "Unknown";
|
||||
};
|
||||
}
|
||||
|
||||
private HttpEntity<?> createHttpEntity(TunnelRequestDto request, HttpHeaders headers) {
|
||||
if (request.getBody() != null && request.getBody().length > 0) {
|
||||
String contentType = headers.getFirst("Content-Type");
|
||||
if (contentType != null && contentType.startsWith("application/json")) {
|
||||
// For JSON, convert bytes to string for proper handling
|
||||
String jsonBody = new String(request.getBody());
|
||||
return new HttpEntity<>(jsonBody, headers);
|
||||
} else if (contentType != null && contentType.startsWith("application/x-www-form-urlencoded")) {
|
||||
// For form data, convert bytes to string
|
||||
String formBody = new String(request.getBody());
|
||||
return new HttpEntity<>(formBody, headers);
|
||||
} else {
|
||||
// For binary data or other content types, use byte array
|
||||
return new HttpEntity<>(request.getBody(), headers);
|
||||
}
|
||||
} else {
|
||||
return new HttpEntity<>(headers);
|
||||
}
|
||||
}
|
||||
|
||||
private Map<String, String> extractAndProcessHeaders(ResponseEntity<byte[]> response, String path) {
|
||||
Map<String, String> responseHeaders = new HashMap<>();
|
||||
response.getHeaders().forEach((key, values) -> {
|
||||
if (!values.isEmpty()) {
|
||||
responseHeaders.put(key, values.get(0));
|
||||
}
|
||||
});
|
||||
|
||||
// Apply header manipulation rules
|
||||
return headerManipulationService.processResponseHeaders(path, responseHeaders);
|
||||
}
|
||||
|
||||
private TunnelResponseDto handleWebSocketConnect(TunnelRequestDto request) {
|
||||
try {
|
||||
String targetUrl = routeResolver.resolveTargetUrl(request.getPath());
|
||||
String wsUrl = targetUrl.replace("http://", "ws://").replace("https://", "wss://") + request.getPath();
|
||||
|
||||
log.info("Establishing WebSocket connection to: {}", wsUrl);
|
||||
|
||||
StandardWebSocketClient client = new StandardWebSocketClient();
|
||||
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
|
||||
|
||||
if (request.getHeaders() != null) {
|
||||
request.getHeaders().forEach(headers::add);
|
||||
}
|
||||
|
||||
// Create handler for target WebSocket
|
||||
WebSocketHandler targetHandler = new WebSocketHandler() {
|
||||
@Override
|
||||
public void afterConnectionEstablished(WebSocketSession targetSession) throws Exception {
|
||||
webSocketConnections.put(request.getWsConnectionId(), targetSession);
|
||||
log.info("WebSocket connection established: {}", request.getWsConnectionId());
|
||||
|
||||
// Send success response back to tunnel server
|
||||
TunnelResponseDto response = new TunnelResponseDto();
|
||||
response.setRequestId(request.getRequestId());
|
||||
response.setType(TunnelRequestType.WS_CONNECT);
|
||||
response.setStatusCode(101); // WebSocket upgrade
|
||||
response.setHeaders(new HashMap<>());
|
||||
response.setBody(new byte[0]);
|
||||
response.setWsConnectionId(request.getWsConnectionId());
|
||||
response.setWsConnectionEstablished(true);
|
||||
|
||||
try {
|
||||
byte[] responseBytes = objectMapper.writeValueAsBytes(response);
|
||||
session.sendMessage(new BinaryMessage(responseBytes));
|
||||
} catch (Exception e) {
|
||||
log.error("Error sending WebSocket connect response", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleMessage(WebSocketSession targetSession, WebSocketMessage<?> message) throws Exception {
|
||||
// Forward message from target to tunnel server
|
||||
String messageType;
|
||||
byte[] messageBody;
|
||||
|
||||
if (message instanceof TextMessage textMsg) {
|
||||
messageType = "TEXT";
|
||||
messageBody = textMsg.getPayload().getBytes();
|
||||
} else if (message instanceof BinaryMessage binaryMsg) {
|
||||
messageType = "BINARY";
|
||||
messageBody = binaryMsg.getPayload().array();
|
||||
} else {
|
||||
return; // Unknown message type
|
||||
}
|
||||
|
||||
TunnelResponseDto response = new TunnelResponseDto(
|
||||
java.util.UUID.randomUUID().toString(),
|
||||
TunnelRequestType.WS_MESSAGE,
|
||||
200,
|
||||
new HashMap<>(),
|
||||
messageBody,
|
||||
request.getWsConnectionId(),
|
||||
messageType,
|
||||
false
|
||||
);
|
||||
|
||||
try {
|
||||
byte[] responseBytes = objectMapper.writeValueAsBytes(response);
|
||||
session.sendMessage(new BinaryMessage(responseBytes));
|
||||
} catch (Exception e) {
|
||||
log.error("Error forwarding WebSocket message", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void afterConnectionClosed(WebSocketSession targetSession, CloseStatus closeStatus) throws Exception {
|
||||
webSocketConnections.remove(request.getWsConnectionId());
|
||||
log.info("WebSocket connection closed: {}", request.getWsConnectionId());
|
||||
|
||||
// Notify tunnel server of connection close
|
||||
TunnelResponseDto response = new TunnelResponseDto(
|
||||
java.util.UUID.randomUUID().toString(),
|
||||
TunnelRequestType.WS_CLOSE,
|
||||
closeStatus.getCode(),
|
||||
new HashMap<>(),
|
||||
new byte[0],
|
||||
request.getWsConnectionId(),
|
||||
null,
|
||||
false
|
||||
);
|
||||
|
||||
try {
|
||||
byte[] responseBytes = objectMapper.writeValueAsBytes(response);
|
||||
session.sendMessage(new BinaryMessage(responseBytes));
|
||||
} catch (Exception e) {
|
||||
log.error("Error sending WebSocket close notification", e);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void handleTransportError(WebSocketSession targetSession, Throwable exception) throws Exception {
|
||||
log.error("WebSocket transport error: {}", exception.getMessage());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean supportsPartialMessages() {
|
||||
return false;
|
||||
}
|
||||
};
|
||||
|
||||
client.execute(targetHandler, headers, URI.create(wsUrl));
|
||||
|
||||
// Return immediate response (actual connection established response sent in handler)
|
||||
TunnelResponseDto response = new TunnelResponseDto();
|
||||
response.setRequestId(request.getRequestId());
|
||||
response.setType(TunnelRequestType.WS_CONNECT);
|
||||
response.setStatusCode(102); // Processing
|
||||
response.setHeaders(new HashMap<>());
|
||||
response.setBody(new byte[0]);
|
||||
response.setWsConnectionId(request.getWsConnectionId());
|
||||
response.setWsConnectionEstablished(false);
|
||||
return response;
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("Error establishing WebSocket connection", e);
|
||||
TunnelResponseDto response = new TunnelResponseDto();
|
||||
response.setRequestId(request.getRequestId());
|
||||
response.setType(TunnelRequestType.WS_CONNECT);
|
||||
response.setStatusCode(500);
|
||||
response.setHeaders(new HashMap<>());
|
||||
response.setBody(("WebSocket connection failed: " + e.getMessage()).getBytes());
|
||||
response.setWsConnectionId(request.getWsConnectionId());
|
||||
response.setWsConnectionEstablished(false);
|
||||
return response;
|
||||
}
|
||||
}
|
||||
|
||||
private TunnelResponseDto handleWebSocketMessage(TunnelRequestDto request) {
|
||||
try {
|
||||
WebSocketSession targetSession = webSocketConnections.get(request.getWsConnectionId());
|
||||
if (targetSession == null || !targetSession.isOpen()) {
|
||||
log.warn("WebSocket connection not found or closed: {}", request.getWsConnectionId());
|
||||
TunnelResponseDto response = new TunnelResponseDto();
|
||||
response.setRequestId(request.getRequestId());
|
||||
response.setType(TunnelRequestType.WS_MESSAGE);
|
||||
response.setStatusCode(404);
|
||||
response.setHeaders(new HashMap<>());
|
||||
response.setBody("WebSocket connection not found".getBytes());
|
||||
response.setWsConnectionId(request.getWsConnectionId());
|
||||
return response;
|
||||
}
|
||||
|
||||
// Forward message to target
|
||||
if ("TEXT".equals(request.getWsMessageType())) {
|
||||
String textPayload = new String(request.getBody());
|
||||
targetSession.sendMessage(new TextMessage(textPayload));
|
||||
} else {
|
||||
targetSession.sendMessage(new BinaryMessage(request.getBody()));
|
||||
}
|
||||
|
||||
TunnelResponseDto response = new TunnelResponseDto();
|
||||
response.setRequestId(request.getRequestId());
|
||||
response.setType(TunnelRequestType.WS_MESSAGE);
|
||||
response.setStatusCode(200);
|
||||
response.setHeaders(new HashMap<>());
|
||||
response.setBody(new byte[0]);
|
||||
response.setWsConnectionId(request.getWsConnectionId());
|
||||
return response;
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("Error handling WebSocket message", e);
|
||||
TunnelResponseDto response = new TunnelResponseDto();
|
||||
response.setRequestId(request.getRequestId());
|
||||
response.setType(TunnelRequestType.WS_MESSAGE);
|
||||
response.setStatusCode(500);
|
||||
response.setHeaders(new HashMap<>());
|
||||
response.setBody(("WebSocket message failed: " + e.getMessage()).getBytes());
|
||||
response.setWsConnectionId(request.getWsConnectionId());
|
||||
return response;
|
||||
}
|
||||
}
|
||||
|
||||
private TunnelResponseDto handleWebSocketClose(TunnelRequestDto request) {
|
||||
try {
|
||||
WebSocketSession targetSession = webSocketConnections.remove(request.getWsConnectionId());
|
||||
if (targetSession != null && targetSession.isOpen()) {
|
||||
targetSession.close();
|
||||
log.info("Closed WebSocket connection: {}", request.getWsConnectionId());
|
||||
}
|
||||
|
||||
TunnelResponseDto response = new TunnelResponseDto();
|
||||
response.setRequestId(request.getRequestId());
|
||||
response.setType(TunnelRequestType.WS_CLOSE);
|
||||
response.setStatusCode(200);
|
||||
response.setHeaders(new HashMap<>());
|
||||
response.setBody(new byte[0]);
|
||||
response.setWsConnectionId(request.getWsConnectionId());
|
||||
return response;
|
||||
|
||||
} catch (Exception e) {
|
||||
log.error("Error closing WebSocket connection", e);
|
||||
TunnelResponseDto response = new TunnelResponseDto();
|
||||
response.setRequestId(request.getRequestId());
|
||||
response.setType(TunnelRequestType.WS_CLOSE);
|
||||
response.setStatusCode(500);
|
||||
response.setHeaders(new HashMap<>());
|
||||
response.setBody(new byte[0]);
|
||||
response.setWsConnectionId(request.getWsConnectionId());
|
||||
return response;
|
||||
}
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user