diff --git a/src/main/java/dev/thinhha/tunnel_client/TunnelClientApplication.java b/src/main/java/dev/thinhha/tunnel_client/TunnelClientApplication.java index 086c7f5..cdd9571 100644 --- a/src/main/java/dev/thinhha/tunnel_client/TunnelClientApplication.java +++ b/src/main/java/dev/thinhha/tunnel_client/TunnelClientApplication.java @@ -20,6 +20,7 @@ public class TunnelClientApplication { Runtime.getRuntime().addShutdownHook(new Thread(() -> { System.out.println("Shutting down tunnel client..."); + tunnelClient.stop(); })); while (true) { diff --git a/src/main/java/dev/thinhha/tunnel_client/service/TunnelClient.java b/src/main/java/dev/thinhha/tunnel_client/service/TunnelClient.java index dd4e465..94299d3 100644 --- a/src/main/java/dev/thinhha/tunnel_client/service/TunnelClient.java +++ b/src/main/java/dev/thinhha/tunnel_client/service/TunnelClient.java @@ -17,8 +17,9 @@ import java.io.IOException; import java.net.URI; import java.util.HashMap; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CountDownLatch; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; @Service @Slf4j @@ -35,6 +36,19 @@ public class TunnelClient extends BinaryWebSocketHandler { // Track WebSocket connections: wsConnectionId -> target WebSocket session private final Map webSocketConnections = new ConcurrentHashMap<>(); + // Retry mechanism fields + private final AtomicInteger retryCount = new AtomicInteger(0); + private final AtomicBoolean isConnecting = new AtomicBoolean(false); + private final AtomicBoolean shouldStop = new AtomicBoolean(false); + private final ScheduledExecutorService retryScheduler = Executors.newSingleThreadScheduledExecutor(); + + // Retry configuration + private static final int MAX_RETRY_ATTEMPTS = 10; + private static final long INITIAL_RETRY_DELAY_MS = 1500; + private static final long MAX_RETRY_DELAY_MS = 60000; + private static final double BACKOFF_MULTIPLIER = 2.0; + private static final long CONNECTION_TIMEOUT_MS = 5000; + public TunnelClient(TunnelConfig tunnelConfig, RouteResolver routeResolver, HeaderManipulationService headerManipulationService, ObjectMapper objectMapper, RestTemplate restTemplate) { this.tunnelConfig = tunnelConfig; this.routeResolver = routeResolver; @@ -44,6 +58,24 @@ public class TunnelClient extends BinaryWebSocketHandler { } public void connect() { + shouldStop.set(false); + retryCount.set(0); + connectWithRetry(); + } + + private void connectWithRetry() { + if (shouldStop.get()) { + log.info("Connection stopped by user"); + return; + } + + if (isConnecting.get()) { + log.debug("Connection already in progress"); + return; + } + + isConnecting.set(true); + try { StandardWebSocketClient client = new StandardWebSocketClient(); WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); @@ -53,21 +85,64 @@ public class TunnelClient extends BinaryWebSocketHandler { } URI serverUri = URI.create(tunnelConfig.getServer().getUrl() + "/client"); - log.info("Connecting to tunnel server at: {}", serverUri); + log.info("Connecting to tunnel server at: {} (attempt {}/{})", serverUri, retryCount.get() + 1, MAX_RETRY_ATTEMPTS); client.execute(this, headers, serverUri); - connectionLatch.await(); + boolean connected = connectionLatch.await(CONNECTION_TIMEOUT_MS, TimeUnit.MILLISECONDS); + if (!connected) { + throw new RuntimeException("Connection timeout after " + CONNECTION_TIMEOUT_MS + "ms"); + } log.info("Connected to tunnel server as client: {}", tunnelConfig.getClient().getName()); + retryCount.set(0); } catch (Exception e) { - log.error("Failed to connect to tunnel server", e); + log.error("Failed to connect to tunnel server (attempt {}/{}): {}", retryCount.get() + 1, MAX_RETRY_ATTEMPTS, e.getMessage()); + scheduleRetry(); + } finally { + isConnecting.set(false); } } + private void scheduleRetry() { + int currentRetryCount = retryCount.incrementAndGet(); + + if (currentRetryCount >= MAX_RETRY_ATTEMPTS) { + log.error("Max retry attempts ({}) reached. Stopping connection attempts.", MAX_RETRY_ATTEMPTS); + return; + } + + long delayMs = calculateRetryDelay(currentRetryCount); + log.info("Scheduling retry {} in {} ms", currentRetryCount, delayMs); + + retryScheduler.schedule(() -> { + log.info("Attempting to reconnect (retry {}/{})", currentRetryCount, MAX_RETRY_ATTEMPTS); + connectWithRetry(); + }, delayMs, TimeUnit.MILLISECONDS); + } + + private long calculateRetryDelay(int attemptCount) { + long delay = (long) (INITIAL_RETRY_DELAY_MS * Math.pow(BACKOFF_MULTIPLIER, attemptCount - 1)); + return Math.min(delay, MAX_RETRY_DELAY_MS); + } + + public void stop() { + shouldStop.set(true); + if (session != null && session.isOpen()) { + try { + session.close(); + } catch (Exception e) { + log.warn("Error closing WebSocket session: {}", e.getMessage()); + } + } + retryScheduler.shutdown(); + log.info("Tunnel client stopped"); + } + @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { this.session = session; + retryCount.set(0); log.info("WebSocket connection established with session: {}", session.getId()); connectionLatch.countDown(); } @@ -166,12 +241,21 @@ public class TunnelClient extends BinaryWebSocketHandler { @Override public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { log.error("Transport error: {}", exception.getMessage()); + if (!shouldStop.get()) { + log.info("Attempting to reconnect after transport error"); + scheduleRetry(); + } } @Override public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception { log.info("Connection closed with status: {}", closeStatus); this.session = null; + + if (!shouldStop.get() && !closeStatus.equals(CloseStatus.NORMAL)) { + log.info("Connection closed unexpectedly, attempting to reconnect"); + scheduleRetry(); + } } @Override