implement real gateway path.
Some checks failed
CI / Rust (push) Successful in 25s
CI / Android (push) Failing after 2s

This commit is contained in:
2026-05-31 20:10:11 +03:30
parent 442fad6b05
commit 266cae92ce
7 changed files with 1062 additions and 17 deletions

View File

@@ -17,10 +17,12 @@ import kotlinx.coroutines.cancel
import kotlinx.coroutines.launch
import org.vpnshare.domain.model.GatewayConfig
import org.vpnshare.engine.RustVpnShareEngine
import org.vpnshare.gateway.socks.UsbSocksGateway
class VpnShareGatewayService : Service() {
private val scope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
private val engine = RustVpnShareEngine()
private val socksGateway = UsbSocksGateway()
private lateinit var vpnDetector: VpnDetector
override fun onCreate() {
@@ -45,11 +47,13 @@ class VpnShareGatewayService : Service() {
override fun onBind(intent: Intent?): IBinder? = null
override fun onDestroy() {
socksGateway.stop()
scope.cancel()
super.onDestroy()
}
private fun startGateway() {
socksGateway.start()
val notification = buildNotification()
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
startForeground(

View File

@@ -0,0 +1,328 @@
package org.vpnshare.gateway.socks
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.Job
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.cancel
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import java.io.EOFException
import java.io.InputStream
import java.io.OutputStream
import java.net.DatagramPacket
import java.net.DatagramSocket
import java.net.InetAddress
import java.net.InetSocketAddress
import java.net.ServerSocket
import java.net.Socket
import java.net.SocketTimeoutException
import java.util.concurrent.atomic.AtomicBoolean
class UsbSocksGateway(
private val port: Int = DEFAULT_PORT
) {
private val running = AtomicBoolean(false)
private var serverSocket: ServerSocket? = null
private var scope: CoroutineScope? = null
private var acceptJob: Job? = null
fun start() {
if (!running.compareAndSet(false, true)) return
val serviceScope = CoroutineScope(SupervisorJob() + Dispatchers.IO)
scope = serviceScope
val server = ServerSocket().apply {
reuseAddress = true
bind(InetSocketAddress(InetAddress.getLoopbackAddress(), port))
}
serverSocket = server
acceptJob = serviceScope.launch {
while (isActive && running.get()) {
val client = runCatching { server.accept() }.getOrNull() ?: break
launch { handleClient(client) }
}
}
}
fun stop() {
running.set(false)
runCatching { serverSocket?.close() }
acceptJob?.cancel()
scope?.cancel()
serverSocket = null
acceptJob = null
scope = null
}
private fun handleClient(client: Socket) {
client.use { clientSocket ->
clientSocket.tcpNoDelay = true
clientSocket.soTimeout = HANDSHAKE_TIMEOUT_MS
val input = clientSocket.getInputStream()
val output = clientSocket.getOutputStream()
runCatching {
negotiate(input, output)
when (val request = readRequest(input)) {
is SocksRequest.TcpConnect -> {
Socket().use { upstream ->
upstream.tcpNoDelay = true
upstream.connect(request.destination, CONNECT_TIMEOUT_MS)
writeSuccess(output, request.destination)
clientSocket.soTimeout = 0
pipeBothWays(clientSocket, upstream)
}
}
SocksRequest.UdpForward -> {
writeSuccess(output, null)
clientSocket.soTimeout = 0
forwardUdpInTcp(input, output)
}
}
}.onFailure {
runCatching { writeFailure(output) }
}
}
}
private fun negotiate(input: InputStream, output: OutputStream) {
val version = input.readRequired()
require(version == SOCKS_VERSION) { "unsupported socks version" }
val methodCount = input.readRequired()
val methods = ByteArray(methodCount)
input.readFully(methods)
require(methods.any { it.toInt() == AUTH_NONE }) { "no supported auth method" }
output.write(byteArrayOf(SOCKS_VERSION.toByte(), AUTH_NONE.toByte()))
output.flush()
}
private fun readRequest(input: InputStream): SocksRequest {
val version = input.readRequired()
val command = input.readRequired()
input.readRequired()
val addressType = input.readRequired()
require(version == SOCKS_VERSION) { "unsupported request version" }
val destination = readAddress(input, addressType)
return when (command) {
COMMAND_CONNECT -> SocksRequest.TcpConnect(destination)
COMMAND_FORWARD_UDP -> SocksRequest.UdpForward
else -> error("unsupported command")
}
}
private fun readAddress(input: InputStream, addressType: Int): InetSocketAddress {
val host = when (addressType) {
ADDRESS_IPV4 -> input.readBytesExact(IPV4_BYTES).joinToString(".") { (it.toInt() and 0xff).toString() }
ADDRESS_DOMAIN -> {
val length = input.readRequired()
input.readBytesExact(length).toString(Charsets.UTF_8)
}
ADDRESS_IPV6 -> InetAddress.getByAddress(input.readBytesExact(IPV6_BYTES)).hostAddress
else -> error("unsupported address type")
}
val port = (input.readRequired() shl 8) or input.readRequired()
return InetSocketAddress(host, port)
}
private fun writeSuccess(output: OutputStream, destination: InetSocketAddress?) {
val port = destination?.port ?: 0
output.write(
byteArrayOf(
SOCKS_VERSION.toByte(),
REPLY_SUCCESS.toByte(),
0,
ADDRESS_IPV4.toByte(),
127,
0,
0,
1,
((port ushr 8) and 0xff).toByte(),
(port and 0xff).toByte()
)
)
output.flush()
}
private fun forwardUdpInTcp(input: InputStream, output: OutputStream) {
DatagramSocket().use { udpSocket ->
val alive = AtomicBoolean(true)
val receiver = Thread {
receiveUdpResponses(udpSocket, output, alive)
}
receiver.name = "vpnshare-socks-udp-receiver"
receiver.start()
try {
while (alive.get()) {
val frame = readUdpTcpFrame(input)
val packet = DatagramPacket(frame.payload, frame.payload.size, frame.destination)
udpSocket.send(packet)
}
} catch (_: EOFException) {
} catch (_: SocketTimeoutException) {
} catch (_: Exception) {
} finally {
alive.set(false)
udpSocket.close()
receiver.join(UDP_RECEIVER_JOIN_MS)
}
}
}
private fun readUdpTcpFrame(input: InputStream): UdpFrame {
val dataLength = (input.readRequired() shl 8) or input.readRequired()
val headerLength = input.readRequired()
require(dataLength <= UDP_MAX_PAYLOAD_BYTES) { "udp frame too large" }
require(headerLength >= UDP_MIN_HEADER_BYTES) { "udp header too small" }
val addressType = input.readRequired()
val destination = readAddress(input, addressType)
val payload = input.readBytesExact(dataLength)
return UdpFrame(destination, payload)
}
private fun receiveUdpResponses(
udpSocket: DatagramSocket,
output: OutputStream,
alive: AtomicBoolean
) {
udpSocket.soTimeout = UDP_RECEIVE_TIMEOUT_MS
val buffer = ByteArray(UDP_RECEIVE_BUFFER_BYTES)
while (alive.get()) {
try {
val packet = DatagramPacket(buffer, buffer.size)
udpSocket.receive(packet)
val response = encodeUdpTcpFrame(packet.address, packet.port, packet.data, packet.offset, packet.length)
synchronized(output) {
output.write(response)
output.flush()
}
} catch (_: SocketTimeoutException) {
} catch (_: Exception) {
alive.set(false)
}
}
}
private fun encodeUdpTcpFrame(
address: InetAddress,
port: Int,
payload: ByteArray,
offset: Int,
length: Int
): ByteArray {
require(length <= UDP_MAX_PAYLOAD_BYTES) { "udp response too large" }
val addressBytes = address.address
val addressType = when (addressBytes.size) {
IPV4_BYTES -> ADDRESS_IPV4
IPV6_BYTES -> ADDRESS_IPV6
else -> error("unsupported udp response address")
}
val socksAddressLength = 1 + addressBytes.size + 2
val headerLength = UDP_PREFIX_BYTES + socksAddressLength
val frame = ByteArray(headerLength + length)
frame[0] = ((length ushr 8) and 0xff).toByte()
frame[1] = (length and 0xff).toByte()
frame[2] = headerLength.toByte()
frame[3] = addressType.toByte()
addressBytes.copyInto(frame, destinationOffset = 4)
val portOffset = 4 + addressBytes.size
frame[portOffset] = ((port ushr 8) and 0xff).toByte()
frame[portOffset + 1] = (port and 0xff).toByte()
payload.copyInto(
frame,
destinationOffset = headerLength,
startIndex = offset,
endIndex = offset + length
)
return frame
}
private fun writeFailure(output: OutputStream) {
output.write(byteArrayOf(SOCKS_VERSION.toByte(), REPLY_FAILURE.toByte(), 0, ADDRESS_IPV4.toByte(), 0, 0, 0, 0, 0, 0))
output.flush()
}
private fun pipeBothWays(left: Socket, right: Socket) {
val leftToRight = Thread { copy(left.getInputStream(), right.getOutputStream(), right) }
val rightToLeft = Thread { copy(right.getInputStream(), left.getOutputStream(), left) }
leftToRight.name = "vpnshare-socks-client-to-phone"
rightToLeft.name = "vpnshare-socks-phone-to-client"
leftToRight.start()
rightToLeft.start()
leftToRight.join()
rightToLeft.join()
}
private fun copy(input: InputStream, output: OutputStream, socketToClose: Socket) {
val buffer = ByteArray(COPY_BUFFER_BYTES)
try {
while (true) {
val read = input.read(buffer)
if (read < 0) break
output.write(buffer, 0, read)
output.flush()
}
} catch (_: SocketTimeoutException) {
} catch (_: Exception) {
} finally {
runCatching { socketToClose.shutdownOutput() }
}
}
private fun InputStream.readRequired(): Int {
val value = read()
if (value < 0) throw EOFException()
return value
}
private fun InputStream.readBytesExact(size: Int): ByteArray {
val bytes = ByteArray(size)
readFully(bytes)
return bytes
}
private fun InputStream.readFully(bytes: ByteArray) {
var offset = 0
while (offset < bytes.size) {
val read = read(bytes, offset, bytes.size - offset)
if (read < 0) throw EOFException()
offset += read
}
}
private sealed interface SocksRequest {
data class TcpConnect(val destination: InetSocketAddress) : SocksRequest
data object UdpForward : SocksRequest
}
private data class UdpFrame(
val destination: InetSocketAddress,
val payload: ByteArray
)
companion object {
const val DEFAULT_PORT = 10808
private const val SOCKS_VERSION = 5
private const val AUTH_NONE = 0
private const val COMMAND_CONNECT = 1
private const val COMMAND_FORWARD_UDP = 5
private const val ADDRESS_IPV4 = 1
private const val ADDRESS_DOMAIN = 3
private const val ADDRESS_IPV6 = 4
private const val REPLY_SUCCESS = 0
private const val REPLY_FAILURE = 1
private const val IPV4_BYTES = 4
private const val IPV6_BYTES = 16
private const val CONNECT_TIMEOUT_MS = 15_000
private const val HANDSHAKE_TIMEOUT_MS = 10_000
private const val COPY_BUFFER_BYTES = 32 * 1024
private const val UDP_PREFIX_BYTES = 3
private const val UDP_MIN_HEADER_BYTES = UDP_PREFIX_BYTES + 7
private const val UDP_MAX_PAYLOAD_BYTES = 65_507
private const val UDP_RECEIVE_BUFFER_BYTES = 65_535
private const val UDP_RECEIVE_TIMEOUT_MS = 1_000
private const val UDP_RECEIVER_JOIN_MS = 1_000L
}
}