package dev.thinhha.tunnel_client.service; import com.fasterxml.jackson.databind.ObjectMapper; import dev.thinhha.tunnel_client.config.TunnelConfig; import dev.thinhha.tunnel_client.dto.TunnelRequestDto; import dev.thinhha.tunnel_client.dto.TunnelResponseDto; import dev.thinhha.tunnel_client.types.TunnelRequestType; import lombok.extern.slf4j.Slf4j; import org.springframework.http.*; import org.springframework.stereotype.Service; import org.springframework.web.client.RestTemplate; import org.springframework.web.socket.*; import org.springframework.web.socket.client.standard.StandardWebSocketClient; import org.springframework.web.socket.handler.BinaryWebSocketHandler; import java.io.IOException; import java.net.URI; import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; @Service @Slf4j public class TunnelClient extends BinaryWebSocketHandler { private final TunnelConfig tunnelConfig; private final RouteResolver routeResolver; private final HeaderManipulationService headerManipulationService; private final ObjectMapper objectMapper; private final RestTemplate restTemplate; private WebSocketSession session; private final CountDownLatch connectionLatch = new CountDownLatch(1); // Track WebSocket connections: wsConnectionId -> target WebSocket session private final Map webSocketConnections = new ConcurrentHashMap<>(); public TunnelClient(TunnelConfig tunnelConfig, RouteResolver routeResolver, HeaderManipulationService headerManipulationService, ObjectMapper objectMapper, RestTemplate restTemplate) { this.tunnelConfig = tunnelConfig; this.routeResolver = routeResolver; this.headerManipulationService = headerManipulationService; this.objectMapper = objectMapper; this.restTemplate = restTemplate; } public void connect() { try { StandardWebSocketClient client = new StandardWebSocketClient(); WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); if (!tunnelConfig.getClient().getToken().isEmpty()) { headers.add("Authorization", "Bearer " + tunnelConfig.getClient().getToken()); } URI serverUri = URI.create(tunnelConfig.getServer().getUrl() + "/client"); log.info("Connecting to tunnel server at: {}", serverUri); client.execute(this, headers, serverUri); connectionLatch.await(); log.info("Connected to tunnel server as client: {}", tunnelConfig.getClient().getName()); } catch (Exception e) { log.error("Failed to connect to tunnel server", e); } } @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { this.session = session; log.info("WebSocket connection established with session: {}", session.getId()); connectionLatch.countDown(); } @Override public void handleBinaryMessage(WebSocketSession session, BinaryMessage message) throws Exception { try { byte[] payload = message.getPayload().array(); TunnelRequestDto request = objectMapper.readValue(payload, TunnelRequestDto.class); log.info("Received tunnel request: {} {} {} {}", request.getRequestId(), request.getType(), request.getMethod(), request.getPath()); TunnelResponseDto response = switch (request.getType()) { case HTTP -> handleHttpTunnelRequest(request); case WS_CONNECT -> handleWebSocketConnect(request); case WS_MESSAGE -> handleWebSocketMessage(request); case WS_CLOSE -> handleWebSocketClose(request); }; byte[] responseBytes = objectMapper.writeValueAsBytes(response); session.sendMessage(new BinaryMessage(responseBytes)); } catch (Exception e) { log.error("Error processing tunnel request", e); } } private TunnelResponseDto handleHttpTunnelRequest(TunnelRequestDto request) { try { String targetUrl = routeResolver.resolveTargetUrl(request.getPath()); String fullUrl = targetUrl + request.getPath(); log.info("Forwarding request {} {} to: {}", request.getMethod(), request.getPath(), fullUrl); HttpHeaders headers = new HttpHeaders(); if (request.getHeaders() != null) { request.getHeaders().forEach((key, value) -> { // Handle Content-Type specifically to ensure proper parsing if ("Content-Type".equalsIgnoreCase(key)) { headers.set(key, value); log.debug("Setting Content-Type: {}", value); } else if ("Content-Length".equalsIgnoreCase(key)) { // Skip Content-Length as RestTemplate will set it automatically log.debug("Skipping Content-Length header (will be set automatically)"); } else { headers.add(key, value); } }); } // Create HTTP entity with proper body handling final HttpEntity httpEntity = createHttpEntity(request, headers); // Use byte[] to handle all response types properly ResponseEntity response = restTemplate.exchange( fullUrl, HttpMethod.valueOf(request.getMethod().name()), httpEntity, byte[].class ); // Extract and process response headers Map responseHeaders = extractAndProcessHeaders(response, request.getPath()); int statusCode = response.getStatusCode().value(); log.info("Target service responded: {} {} -> {} {}", request.getMethod(), request.getPath(), statusCode, getStatusCodeDescription(statusCode)); TunnelResponseDto tunnelResponse = new TunnelResponseDto(); tunnelResponse.setRequestId(request.getRequestId()); tunnelResponse.setType(TunnelRequestType.HTTP); tunnelResponse.setStatusCode(statusCode); tunnelResponse.setHeaders(responseHeaders); tunnelResponse.setBody(response.getBody()); return tunnelResponse; } catch (Exception e) { log.error("Error forwarding request to target service", e); Map errorHeaders = new HashMap<>(); errorHeaders.put("Content-Type", "text/plain"); TunnelResponseDto errorResponse = new TunnelResponseDto(); errorResponse.setRequestId(request.getRequestId()); errorResponse.setType(TunnelRequestType.HTTP); errorResponse.setStatusCode(500); errorResponse.setHeaders(errorHeaders); errorResponse.setBody(("Internal Server Error: " + e.getMessage()).getBytes()); return errorResponse; } } @Override public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { log.error("Transport error: {}", exception.getMessage()); } @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception { log.info("Connection closed with status: {}", closeStatus); this.session = null; } @Override public boolean supportsPartialMessages() { return false; } private String getStatusCodeDescription(int statusCode) { return switch (statusCode / 100) { case 2 -> "Success"; case 3 -> "Redirection"; case 4 -> "Client Error"; case 5 -> "Server Error"; default -> "Unknown"; }; } private HttpEntity createHttpEntity(TunnelRequestDto request, HttpHeaders headers) { if (request.getBody() != null && request.getBody().length > 0) { String contentType = headers.getFirst("Content-Type"); if (contentType != null && contentType.startsWith("application/json")) { // For JSON, convert bytes to string for proper handling String jsonBody = new String(request.getBody()); return new HttpEntity<>(jsonBody, headers); } else if (contentType != null && contentType.startsWith("application/x-www-form-urlencoded")) { // For form data, convert bytes to string String formBody = new String(request.getBody()); return new HttpEntity<>(formBody, headers); } else { // For binary data or other content types, use byte array return new HttpEntity<>(request.getBody(), headers); } } else { return new HttpEntity<>(headers); } } private Map extractAndProcessHeaders(ResponseEntity response, String path) { Map responseHeaders = new HashMap<>(); response.getHeaders().forEach((key, values) -> { if (!values.isEmpty()) { responseHeaders.put(key, values.get(0)); } }); // Apply header manipulation rules return headerManipulationService.processResponseHeaders(path, responseHeaders); } private TunnelResponseDto handleWebSocketConnect(TunnelRequestDto request) { try { String targetUrl = routeResolver.resolveTargetUrl(request.getPath()); String wsUrl = targetUrl.replace("http://", "ws://").replace("https://", "wss://") + request.getPath(); log.info("Establishing WebSocket connection to: {}", wsUrl); StandardWebSocketClient client = new StandardWebSocketClient(); WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); if (request.getHeaders() != null) { request.getHeaders().forEach(headers::add); } // Create handler for target WebSocket WebSocketHandler targetHandler = new WebSocketHandler() { @Override public void afterConnectionEstablished(WebSocketSession targetSession) throws Exception { webSocketConnections.put(request.getWsConnectionId(), targetSession); log.info("WebSocket connection established: {}", request.getWsConnectionId()); // Send success response back to tunnel server TunnelResponseDto response = new TunnelResponseDto(); response.setRequestId(request.getRequestId()); response.setType(TunnelRequestType.WS_CONNECT); response.setStatusCode(101); // WebSocket upgrade response.setHeaders(new HashMap<>()); response.setBody(new byte[0]); response.setWsConnectionId(request.getWsConnectionId()); response.setWsConnectionEstablished(true); try { byte[] responseBytes = objectMapper.writeValueAsBytes(response); session.sendMessage(new BinaryMessage(responseBytes)); } catch (Exception e) { log.error("Error sending WebSocket connect response", e); } } @Override public void handleMessage(WebSocketSession targetSession, WebSocketMessage message) throws Exception { // Forward message from target to tunnel server String messageType; byte[] messageBody; if (message instanceof TextMessage textMsg) { messageType = "TEXT"; messageBody = textMsg.getPayload().getBytes(); } else if (message instanceof BinaryMessage binaryMsg) { messageType = "BINARY"; messageBody = binaryMsg.getPayload().array(); } else { return; // Unknown message type } TunnelResponseDto response = new TunnelResponseDto( java.util.UUID.randomUUID().toString(), TunnelRequestType.WS_MESSAGE, 200, new HashMap<>(), messageBody, request.getWsConnectionId(), messageType, false ); try { byte[] responseBytes = objectMapper.writeValueAsBytes(response); session.sendMessage(new BinaryMessage(responseBytes)); } catch (Exception e) { log.error("Error forwarding WebSocket message", e); } } @Override public void afterConnectionClosed(WebSocketSession targetSession, CloseStatus closeStatus) throws Exception { webSocketConnections.remove(request.getWsConnectionId()); log.info("WebSocket connection closed: {}", request.getWsConnectionId()); // Notify tunnel server of connection close TunnelResponseDto response = new TunnelResponseDto( java.util.UUID.randomUUID().toString(), TunnelRequestType.WS_CLOSE, closeStatus.getCode(), new HashMap<>(), new byte[0], request.getWsConnectionId(), null, false ); try { byte[] responseBytes = objectMapper.writeValueAsBytes(response); session.sendMessage(new BinaryMessage(responseBytes)); } catch (Exception e) { log.error("Error sending WebSocket close notification", e); } } @Override public void handleTransportError(WebSocketSession targetSession, Throwable exception) throws Exception { log.error("WebSocket transport error: {}", exception.getMessage()); } @Override public boolean supportsPartialMessages() { return false; } }; client.execute(targetHandler, headers, URI.create(wsUrl)); // Return immediate response (actual connection established response sent in handler) TunnelResponseDto response = new TunnelResponseDto(); response.setRequestId(request.getRequestId()); response.setType(TunnelRequestType.WS_CONNECT); response.setStatusCode(102); // Processing response.setHeaders(new HashMap<>()); response.setBody(new byte[0]); response.setWsConnectionId(request.getWsConnectionId()); response.setWsConnectionEstablished(false); return response; } catch (Exception e) { log.error("Error establishing WebSocket connection", e); TunnelResponseDto response = new TunnelResponseDto(); response.setRequestId(request.getRequestId()); response.setType(TunnelRequestType.WS_CONNECT); response.setStatusCode(500); response.setHeaders(new HashMap<>()); response.setBody(("WebSocket connection failed: " + e.getMessage()).getBytes()); response.setWsConnectionId(request.getWsConnectionId()); response.setWsConnectionEstablished(false); return response; } } private TunnelResponseDto handleWebSocketMessage(TunnelRequestDto request) { try { WebSocketSession targetSession = webSocketConnections.get(request.getWsConnectionId()); if (targetSession == null || !targetSession.isOpen()) { log.warn("WebSocket connection not found or closed: {}", request.getWsConnectionId()); TunnelResponseDto response = new TunnelResponseDto(); response.setRequestId(request.getRequestId()); response.setType(TunnelRequestType.WS_MESSAGE); response.setStatusCode(404); response.setHeaders(new HashMap<>()); response.setBody("WebSocket connection not found".getBytes()); response.setWsConnectionId(request.getWsConnectionId()); return response; } // Forward message to target if ("TEXT".equals(request.getWsMessageType())) { String textPayload = new String(request.getBody()); targetSession.sendMessage(new TextMessage(textPayload)); } else { targetSession.sendMessage(new BinaryMessage(request.getBody())); } TunnelResponseDto response = new TunnelResponseDto(); response.setRequestId(request.getRequestId()); response.setType(TunnelRequestType.WS_MESSAGE); response.setStatusCode(200); response.setHeaders(new HashMap<>()); response.setBody(new byte[0]); response.setWsConnectionId(request.getWsConnectionId()); return response; } catch (Exception e) { log.error("Error handling WebSocket message", e); TunnelResponseDto response = new TunnelResponseDto(); response.setRequestId(request.getRequestId()); response.setType(TunnelRequestType.WS_MESSAGE); response.setStatusCode(500); response.setHeaders(new HashMap<>()); response.setBody(("WebSocket message failed: " + e.getMessage()).getBytes()); response.setWsConnectionId(request.getWsConnectionId()); return response; } } private TunnelResponseDto handleWebSocketClose(TunnelRequestDto request) { try { WebSocketSession targetSession = webSocketConnections.remove(request.getWsConnectionId()); if (targetSession != null && targetSession.isOpen()) { targetSession.close(); log.info("Closed WebSocket connection: {}", request.getWsConnectionId()); } TunnelResponseDto response = new TunnelResponseDto(); response.setRequestId(request.getRequestId()); response.setType(TunnelRequestType.WS_CLOSE); response.setStatusCode(200); response.setHeaders(new HashMap<>()); response.setBody(new byte[0]); response.setWsConnectionId(request.getWsConnectionId()); return response; } catch (Exception e) { log.error("Error closing WebSocket connection", e); TunnelResponseDto response = new TunnelResponseDto(); response.setRequestId(request.getRequestId()); response.setType(TunnelRequestType.WS_CLOSE); response.setStatusCode(500); response.setHeaders(new HashMap<>()); response.setBody(new byte[0]); response.setWsConnectionId(request.getWsConnectionId()); return response; } } }