435 lines
20 KiB
Java
435 lines
20 KiB
Java
package dev.thinhha.tunnel_client.service;
|
|
|
|
import com.fasterxml.jackson.databind.ObjectMapper;
import dev.thinhha.tunnel_client.config.TunnelConfig;
import dev.thinhha.tunnel_client.dto.TunnelRequestDto;
import dev.thinhha.tunnel_client.dto.TunnelResponseDto;
import dev.thinhha.tunnel_client.types.TunnelRequestType;
import lombok.extern.slf4j.Slf4j;
import org.springframework.http.*;
import org.springframework.stereotype.Service;
import org.springframework.web.client.HttpStatusCodeException;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.socket.*;
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
import org.springframework.web.socket.handler.BinaryWebSocketHandler;

import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.locks.ReentrantLock;
|
|
|
|
@Service
|
|
@Slf4j
|
|
public class TunnelClient extends BinaryWebSocketHandler {
|
|
|
|
private final TunnelConfig tunnelConfig;
|
|
private final RouteResolver routeResolver;
|
|
private final HeaderManipulationService headerManipulationService;
|
|
private final ObjectMapper objectMapper;
|
|
private final RestTemplate restTemplate;
|
|
private WebSocketSession session;
|
|
private final CountDownLatch connectionLatch = new CountDownLatch(1);
|
|
|
|
// Track WebSocket connections: wsConnectionId -> target WebSocket session
|
|
private final Map<String, WebSocketSession> webSocketConnections = new ConcurrentHashMap<>();
|
|
|
|
public TunnelClient(TunnelConfig tunnelConfig, RouteResolver routeResolver, HeaderManipulationService headerManipulationService, ObjectMapper objectMapper, RestTemplate restTemplate) {
|
|
this.tunnelConfig = tunnelConfig;
|
|
this.routeResolver = routeResolver;
|
|
this.headerManipulationService = headerManipulationService;
|
|
this.objectMapper = objectMapper;
|
|
this.restTemplate = restTemplate;
|
|
}
|
|
|
|
public void connect() {
|
|
try {
|
|
StandardWebSocketClient client = new StandardWebSocketClient();
|
|
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
|
|
|
|
if (!tunnelConfig.getClient().getToken().isEmpty()) {
|
|
headers.add("Authorization", "Bearer " + tunnelConfig.getClient().getToken());
|
|
}
|
|
|
|
URI serverUri = URI.create(tunnelConfig.getServer().getUrl() + "/client");
|
|
log.info("Connecting to tunnel server at: {}", serverUri);
|
|
|
|
client.execute(this, headers, serverUri);
|
|
|
|
connectionLatch.await();
|
|
log.info("Connected to tunnel server as client: {}", tunnelConfig.getClient().getName());
|
|
|
|
} catch (Exception e) {
|
|
log.error("Failed to connect to tunnel server", e);
|
|
}
|
|
}
|
|
|
|
@Override
|
|
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
|
|
this.session = session;
|
|
log.info("WebSocket connection established with session: {}", session.getId());
|
|
connectionLatch.countDown();
|
|
}
|
|
|
|
@Override
|
|
public void handleBinaryMessage(WebSocketSession session, BinaryMessage message) throws Exception {
|
|
try {
|
|
byte[] payload = message.getPayload().array();
|
|
TunnelRequestDto request = objectMapper.readValue(payload, TunnelRequestDto.class);
|
|
|
|
log.info("Received tunnel request: {} {} {} {}", request.getRequestId(), request.getType(), request.getMethod(), request.getPath());
|
|
|
|
TunnelResponseDto response = switch (request.getType()) {
|
|
case HTTP -> handleHttpTunnelRequest(request);
|
|
case WS_CONNECT -> handleWebSocketConnect(request);
|
|
case WS_MESSAGE -> handleWebSocketMessage(request);
|
|
case WS_CLOSE -> handleWebSocketClose(request);
|
|
};
|
|
|
|
byte[] responseBytes = objectMapper.writeValueAsBytes(response);
|
|
session.sendMessage(new BinaryMessage(responseBytes));
|
|
|
|
} catch (Exception e) {
|
|
log.error("Error processing tunnel request", e);
|
|
}
|
|
}
|
|
|
|
private TunnelResponseDto handleHttpTunnelRequest(TunnelRequestDto request) {
|
|
try {
|
|
String targetUrl = routeResolver.resolveTargetUrl(request.getPath());
|
|
String fullUrl = targetUrl + request.getPath();
|
|
|
|
log.info("Forwarding request {} {} to: {}", request.getMethod(), request.getPath(), fullUrl);
|
|
|
|
HttpHeaders headers = new HttpHeaders();
|
|
if (request.getHeaders() != null) {
|
|
request.getHeaders().forEach((key, value) -> {
|
|
// Handle Content-Type specifically to ensure proper parsing
|
|
if ("Content-Type".equalsIgnoreCase(key)) {
|
|
headers.set(key, value);
|
|
log.debug("Setting Content-Type: {}", value);
|
|
} else if ("Content-Length".equalsIgnoreCase(key)) {
|
|
// Skip Content-Length as RestTemplate will set it automatically
|
|
log.debug("Skipping Content-Length header (will be set automatically)");
|
|
} else {
|
|
headers.add(key, value);
|
|
}
|
|
});
|
|
}
|
|
|
|
// Create HTTP entity with proper body handling
|
|
final HttpEntity<?> httpEntity = createHttpEntity(request, headers);
|
|
|
|
// Use byte[] to handle all response types properly
|
|
ResponseEntity<byte[]> response = restTemplate.exchange(
|
|
fullUrl,
|
|
HttpMethod.valueOf(request.getMethod().name()),
|
|
httpEntity,
|
|
byte[].class
|
|
);
|
|
|
|
// Extract and process response headers
|
|
Map<String, String> responseHeaders = extractAndProcessHeaders(response, request.getPath());
|
|
|
|
int statusCode = response.getStatusCode().value();
|
|
log.info("Target service responded: {} {} -> {} {}",
|
|
request.getMethod(), request.getPath(), statusCode,
|
|
getStatusCodeDescription(statusCode));
|
|
|
|
TunnelResponseDto tunnelResponse = new TunnelResponseDto();
|
|
tunnelResponse.setRequestId(request.getRequestId());
|
|
tunnelResponse.setType(TunnelRequestType.HTTP);
|
|
tunnelResponse.setStatusCode(statusCode);
|
|
tunnelResponse.setHeaders(responseHeaders);
|
|
tunnelResponse.setBody(response.getBody());
|
|
|
|
return tunnelResponse;
|
|
|
|
} catch (Exception e) {
|
|
log.error("Error forwarding request to target service", e);
|
|
|
|
Map<String, String> errorHeaders = new HashMap<>();
|
|
errorHeaders.put("Content-Type", "text/plain");
|
|
|
|
TunnelResponseDto errorResponse = new TunnelResponseDto();
|
|
errorResponse.setRequestId(request.getRequestId());
|
|
errorResponse.setType(TunnelRequestType.HTTP);
|
|
errorResponse.setStatusCode(500);
|
|
errorResponse.setHeaders(errorHeaders);
|
|
errorResponse.setBody(("Internal Server Error: " + e.getMessage()).getBytes());
|
|
|
|
return errorResponse;
|
|
}
|
|
}
|
|
|
|
@Override
|
|
public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception {
|
|
log.error("Transport error: {}", exception.getMessage());
|
|
}
|
|
|
|
@Override
|
|
public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception {
|
|
log.info("Connection closed with status: {}", closeStatus);
|
|
this.session = null;
|
|
}
|
|
|
|
@Override
|
|
public boolean supportsPartialMessages() {
|
|
return false;
|
|
}
|
|
|
|
private String getStatusCodeDescription(int statusCode) {
|
|
return switch (statusCode / 100) {
|
|
case 2 -> "Success";
|
|
case 3 -> "Redirection";
|
|
case 4 -> "Client Error";
|
|
case 5 -> "Server Error";
|
|
default -> "Unknown";
|
|
};
|
|
}
|
|
|
|
private HttpEntity<?> createHttpEntity(TunnelRequestDto request, HttpHeaders headers) {
|
|
if (request.getBody() != null && request.getBody().length > 0) {
|
|
String contentType = headers.getFirst("Content-Type");
|
|
if (contentType != null && contentType.startsWith("application/json")) {
|
|
// For JSON, convert bytes to string for proper handling
|
|
String jsonBody = new String(request.getBody());
|
|
return new HttpEntity<>(jsonBody, headers);
|
|
} else if (contentType != null && contentType.startsWith("application/x-www-form-urlencoded")) {
|
|
// For form data, convert bytes to string
|
|
String formBody = new String(request.getBody());
|
|
return new HttpEntity<>(formBody, headers);
|
|
} else {
|
|
// For binary data or other content types, use byte array
|
|
return new HttpEntity<>(request.getBody(), headers);
|
|
}
|
|
} else {
|
|
return new HttpEntity<>(headers);
|
|
}
|
|
}
|
|
|
|
private Map<String, String> extractAndProcessHeaders(ResponseEntity<byte[]> response, String path) {
|
|
Map<String, String> responseHeaders = new HashMap<>();
|
|
response.getHeaders().forEach((key, values) -> {
|
|
if (!values.isEmpty()) {
|
|
responseHeaders.put(key, values.get(0));
|
|
}
|
|
});
|
|
|
|
// Apply header manipulation rules
|
|
return headerManipulationService.processResponseHeaders(path, responseHeaders);
|
|
}
|
|
|
|
private TunnelResponseDto handleWebSocketConnect(TunnelRequestDto request) {
|
|
try {
|
|
String targetUrl = routeResolver.resolveTargetUrl(request.getPath());
|
|
String wsUrl = targetUrl.replace("http://", "ws://").replace("https://", "wss://") + request.getPath();
|
|
|
|
log.info("Establishing WebSocket connection to: {}", wsUrl);
|
|
|
|
StandardWebSocketClient client = new StandardWebSocketClient();
|
|
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
|
|
|
|
if (request.getHeaders() != null) {
|
|
request.getHeaders().forEach(headers::add);
|
|
}
|
|
|
|
// Create handler for target WebSocket
|
|
WebSocketHandler targetHandler = new WebSocketHandler() {
|
|
@Override
|
|
public void afterConnectionEstablished(WebSocketSession targetSession) throws Exception {
|
|
webSocketConnections.put(request.getWsConnectionId(), targetSession);
|
|
log.info("WebSocket connection established: {}", request.getWsConnectionId());
|
|
|
|
// Send success response back to tunnel server
|
|
TunnelResponseDto response = new TunnelResponseDto();
|
|
response.setRequestId(request.getRequestId());
|
|
response.setType(TunnelRequestType.WS_CONNECT);
|
|
response.setStatusCode(101); // WebSocket upgrade
|
|
response.setHeaders(new HashMap<>());
|
|
response.setBody(new byte[0]);
|
|
response.setWsConnectionId(request.getWsConnectionId());
|
|
response.setWsConnectionEstablished(true);
|
|
|
|
try {
|
|
byte[] responseBytes = objectMapper.writeValueAsBytes(response);
|
|
session.sendMessage(new BinaryMessage(responseBytes));
|
|
} catch (Exception e) {
|
|
log.error("Error sending WebSocket connect response", e);
|
|
}
|
|
}
|
|
|
|
@Override
|
|
public void handleMessage(WebSocketSession targetSession, WebSocketMessage<?> message) throws Exception {
|
|
// Forward message from target to tunnel server
|
|
String messageType;
|
|
byte[] messageBody;
|
|
|
|
if (message instanceof TextMessage textMsg) {
|
|
messageType = "TEXT";
|
|
messageBody = textMsg.getPayload().getBytes();
|
|
} else if (message instanceof BinaryMessage binaryMsg) {
|
|
messageType = "BINARY";
|
|
messageBody = binaryMsg.getPayload().array();
|
|
} else {
|
|
return; // Unknown message type
|
|
}
|
|
|
|
TunnelResponseDto response = new TunnelResponseDto(
|
|
java.util.UUID.randomUUID().toString(),
|
|
TunnelRequestType.WS_MESSAGE,
|
|
200,
|
|
new HashMap<>(),
|
|
messageBody,
|
|
request.getWsConnectionId(),
|
|
messageType,
|
|
false
|
|
);
|
|
|
|
try {
|
|
byte[] responseBytes = objectMapper.writeValueAsBytes(response);
|
|
session.sendMessage(new BinaryMessage(responseBytes));
|
|
} catch (Exception e) {
|
|
log.error("Error forwarding WebSocket message", e);
|
|
}
|
|
}
|
|
|
|
@Override
|
|
public void afterConnectionClosed(WebSocketSession targetSession, CloseStatus closeStatus) throws Exception {
|
|
webSocketConnections.remove(request.getWsConnectionId());
|
|
log.info("WebSocket connection closed: {}", request.getWsConnectionId());
|
|
|
|
// Notify tunnel server of connection close
|
|
TunnelResponseDto response = new TunnelResponseDto(
|
|
java.util.UUID.randomUUID().toString(),
|
|
TunnelRequestType.WS_CLOSE,
|
|
closeStatus.getCode(),
|
|
new HashMap<>(),
|
|
new byte[0],
|
|
request.getWsConnectionId(),
|
|
null,
|
|
false
|
|
);
|
|
|
|
try {
|
|
byte[] responseBytes = objectMapper.writeValueAsBytes(response);
|
|
session.sendMessage(new BinaryMessage(responseBytes));
|
|
} catch (Exception e) {
|
|
log.error("Error sending WebSocket close notification", e);
|
|
}
|
|
}
|
|
|
|
@Override
|
|
public void handleTransportError(WebSocketSession targetSession, Throwable exception) throws Exception {
|
|
log.error("WebSocket transport error: {}", exception.getMessage());
|
|
}
|
|
|
|
@Override
|
|
public boolean supportsPartialMessages() {
|
|
return false;
|
|
}
|
|
};
|
|
|
|
client.execute(targetHandler, headers, URI.create(wsUrl));
|
|
|
|
// Return immediate response (actual connection established response sent in handler)
|
|
TunnelResponseDto response = new TunnelResponseDto();
|
|
response.setRequestId(request.getRequestId());
|
|
response.setType(TunnelRequestType.WS_CONNECT);
|
|
response.setStatusCode(102); // Processing
|
|
response.setHeaders(new HashMap<>());
|
|
response.setBody(new byte[0]);
|
|
response.setWsConnectionId(request.getWsConnectionId());
|
|
response.setWsConnectionEstablished(false);
|
|
return response;
|
|
|
|
} catch (Exception e) {
|
|
log.error("Error establishing WebSocket connection", e);
|
|
TunnelResponseDto response = new TunnelResponseDto();
|
|
response.setRequestId(request.getRequestId());
|
|
response.setType(TunnelRequestType.WS_CONNECT);
|
|
response.setStatusCode(500);
|
|
response.setHeaders(new HashMap<>());
|
|
response.setBody(("WebSocket connection failed: " + e.getMessage()).getBytes());
|
|
response.setWsConnectionId(request.getWsConnectionId());
|
|
response.setWsConnectionEstablished(false);
|
|
return response;
|
|
}
|
|
}
|
|
|
|
private TunnelResponseDto handleWebSocketMessage(TunnelRequestDto request) {
|
|
try {
|
|
WebSocketSession targetSession = webSocketConnections.get(request.getWsConnectionId());
|
|
if (targetSession == null || !targetSession.isOpen()) {
|
|
log.warn("WebSocket connection not found or closed: {}", request.getWsConnectionId());
|
|
TunnelResponseDto response = new TunnelResponseDto();
|
|
response.setRequestId(request.getRequestId());
|
|
response.setType(TunnelRequestType.WS_MESSAGE);
|
|
response.setStatusCode(404);
|
|
response.setHeaders(new HashMap<>());
|
|
response.setBody("WebSocket connection not found".getBytes());
|
|
response.setWsConnectionId(request.getWsConnectionId());
|
|
return response;
|
|
}
|
|
|
|
// Forward message to target
|
|
if ("TEXT".equals(request.getWsMessageType())) {
|
|
String textPayload = new String(request.getBody());
|
|
targetSession.sendMessage(new TextMessage(textPayload));
|
|
} else {
|
|
targetSession.sendMessage(new BinaryMessage(request.getBody()));
|
|
}
|
|
|
|
TunnelResponseDto response = new TunnelResponseDto();
|
|
response.setRequestId(request.getRequestId());
|
|
response.setType(TunnelRequestType.WS_MESSAGE);
|
|
response.setStatusCode(200);
|
|
response.setHeaders(new HashMap<>());
|
|
response.setBody(new byte[0]);
|
|
response.setWsConnectionId(request.getWsConnectionId());
|
|
return response;
|
|
|
|
} catch (Exception e) {
|
|
log.error("Error handling WebSocket message", e);
|
|
TunnelResponseDto response = new TunnelResponseDto();
|
|
response.setRequestId(request.getRequestId());
|
|
response.setType(TunnelRequestType.WS_MESSAGE);
|
|
response.setStatusCode(500);
|
|
response.setHeaders(new HashMap<>());
|
|
response.setBody(("WebSocket message failed: " + e.getMessage()).getBytes());
|
|
response.setWsConnectionId(request.getWsConnectionId());
|
|
return response;
|
|
}
|
|
}
|
|
|
|
private TunnelResponseDto handleWebSocketClose(TunnelRequestDto request) {
|
|
try {
|
|
WebSocketSession targetSession = webSocketConnections.remove(request.getWsConnectionId());
|
|
if (targetSession != null && targetSession.isOpen()) {
|
|
targetSession.close();
|
|
log.info("Closed WebSocket connection: {}", request.getWsConnectionId());
|
|
}
|
|
|
|
TunnelResponseDto response = new TunnelResponseDto();
|
|
response.setRequestId(request.getRequestId());
|
|
response.setType(TunnelRequestType.WS_CLOSE);
|
|
response.setStatusCode(200);
|
|
response.setHeaders(new HashMap<>());
|
|
response.setBody(new byte[0]);
|
|
response.setWsConnectionId(request.getWsConnectionId());
|
|
return response;
|
|
|
|
} catch (Exception e) {
|
|
log.error("Error closing WebSocket connection", e);
|
|
TunnelResponseDto response = new TunnelResponseDto();
|
|
response.setRequestId(request.getRequestId());
|
|
response.setType(TunnelRequestType.WS_CLOSE);
|
|
response.setStatusCode(500);
|
|
response.setHeaders(new HashMap<>());
|
|
response.setBody(new byte[0]);
|
|
response.setWsConnectionId(request.getWsConnectionId());
|
|
return response;
|
|
}
|
|
}
|
|
} |