more
This commit is contained in:
parent
7c6dca6e40
commit
a5ae45e0b9
30
main.go
30
main.go
|
@ -2,7 +2,6 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"socks5"
|
"socks5"
|
||||||
"golang.org/x/net/context"
|
|
||||||
"golang.org/x/net/proxy"
|
"golang.org/x/net/proxy"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
@ -14,38 +13,53 @@ func main() {
|
||||||
args := os.Args[1:]
|
args := os.Args[1:]
|
||||||
|
|
||||||
if len(args) < 2 {
|
if len(args) < 2 {
|
||||||
fmt.Printf("usage: %s bindaddr upstreamaddr\n", os.Args[0])
|
fmt.Printf("usage: %s bindaddr onionsocksaddr [i2psocksaddr]\n", os.Args[0])
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
var onion, i2p proxy.Dialer
|
||||||
upstream, err:= proxy.SOCKS5("tcp", os.Args[2], nil, nil)
|
var err error
|
||||||
|
onion, err = proxy.SOCKS5("tcp", os.Args[2], nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("failed to create upstream proxy to %s, %s", os.Args[2], err.Error())
|
fmt.Printf("failed to create onion proxy to %s, %s\n", os.Args[2], err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if len(args) > 3 {
|
||||||
|
i2p, err = proxy.SOCKS5("tcp", os.Args[3], nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Printf("failed to create i2p proxy to %s, %s\n", os.Args[3], err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
serv, err := socks5.New(&socks5.Config{
|
serv, err := socks5.New(&socks5.Config{
|
||||||
Dial: func(addr string) (net.Conn, error) {
|
Dial: func(addr string) (net.Conn, error) {
|
||||||
host, _, err := net.SplitHostPort(addr)
|
host, _, err := net.SplitHostPort(addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
if strings.HasSuffix(host, ".i2p") {
|
||||||
|
if i2p == nil {
|
||||||
|
return onion.Dial("tcp", addr)
|
||||||
|
}
|
||||||
|
return i2p.Dial("tcp", addr)
|
||||||
|
}
|
||||||
if strings.HasSuffix(host, ".onion") {
|
if strings.HasSuffix(host, ".onion") {
|
||||||
return upstream.Dial("tcp", addr)
|
return onion.Dial("tcp", addr)
|
||||||
}
|
}
|
||||||
return net.Dial("tcp", addr)
|
return net.Dial("tcp", addr)
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("failed to create socks proxy %s", err.Error())
|
fmt.Printf("failed to create socks proxy %s\n", err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
l, err := net.Listen("tcp", os.Args[1])
|
l, err := net.Listen("tcp", os.Args[1])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("failed to listen on %s, %s", os.Args[1], err.Error())
|
fmt.Printf("failed to listen on %s, %s\n", os.Args[1], err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
fmt.Printf("proxy serving on %s\n", os.Args[1])
|
||||||
serv.Serve(l)
|
serv.Serve(l)
|
||||||
}
|
}
|
||||||
|
|
|
@ -79,43 +79,37 @@ func (req *Request) ConnectAddress() string {
|
||||||
return req.DestAddr.Address()
|
return req.DestAddr.Address()
|
||||||
}
|
}
|
||||||
|
|
||||||
type conn interface {
|
type socksConn interface {
|
||||||
io.WriteCloser
|
io.WriteCloser
|
||||||
RemoteAddr() net.Addr
|
RemoteAddr() net.Addr
|
||||||
CloseWrite() error
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewRequest creates a new Request from the tcp connection
|
// readRequest creates a new Request from the tcp connection
|
||||||
func NewRequest(bufConn io.Reader) (*Request, error) {
|
func readRequest(bufConn io.Reader, req *Request) error {
|
||||||
// Read the version byte
|
// Read the version byte
|
||||||
header := []byte{0, 0, 0}
|
header := []byte{0, 0, 0}
|
||||||
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
|
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
|
||||||
return nil, fmt.Errorf("Failed to get command version: %v", err)
|
return fmt.Errorf("Failed to get command version: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure we are compatible
|
// Ensure we are compatible
|
||||||
if header[0] != socks5Version {
|
if header[0] != socks5Version {
|
||||||
return nil, fmt.Errorf("Unsupported command version: %v", header[0])
|
return fmt.Errorf("Unsupported command version: %v", header[0])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read in the destination address
|
// Read in the destination address
|
||||||
dest, err := readAddrSpec(bufConn)
|
err := readAddrSpec(bufConn, &req.DestAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
req.Version = socks5Version
|
||||||
request := &Request{
|
req.Command = header[1]
|
||||||
Version: socks5Version,
|
req.bufConn = bufConn
|
||||||
Command: header[1],
|
return nil
|
||||||
DestAddr: dest,
|
|
||||||
bufConn: bufConn,
|
|
||||||
}
|
|
||||||
|
|
||||||
return request, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleRequest is used for request processing after authentication
|
// handleRequest is used for request processing after authentication
|
||||||
func (s *Server) handleRequest(req *Request, conn conn) error {
|
func (s *Server) handleRequest(req *Request, conn socksConn) error {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
// Switch on the command
|
// Switch on the command
|
||||||
|
@ -124,8 +118,6 @@ func (s *Server) handleRequest(req *Request, conn conn) error {
|
||||||
return s.handleConnect(ctx, conn, req)
|
return s.handleConnect(ctx, conn, req)
|
||||||
case BindCommand:
|
case BindCommand:
|
||||||
return s.handleBind(ctx, conn, req)
|
return s.handleBind(ctx, conn, req)
|
||||||
case AssociateCommand:
|
|
||||||
return s.handleAssociate(ctx, conn, req)
|
|
||||||
default:
|
default:
|
||||||
if err := sendReply(conn, commandNotSupported, nil); err != nil {
|
if err := sendReply(conn, commandNotSupported, nil); err != nil {
|
||||||
return fmt.Errorf("Failed to send reply: %v", err)
|
return fmt.Errorf("Failed to send reply: %v", err)
|
||||||
|
@ -135,7 +127,7 @@ func (s *Server) handleRequest(req *Request, conn conn) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleConnect is used to handle a connect command
|
// handleConnect is used to handle a connect command
|
||||||
func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error {
|
func (s *Server) handleConnect(ctx context.Context, conn socksConn, req *Request) error {
|
||||||
// Check if this is allowed
|
// Check if this is allowed
|
||||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
||||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||||
|
@ -149,11 +141,11 @@ func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) err
|
||||||
// Attempt to connect
|
// Attempt to connect
|
||||||
dial := s.config.Dial
|
dial := s.config.Dial
|
||||||
if dial == nil {
|
if dial == nil {
|
||||||
dial = func(ctx context.Context, net_, addr string) (net.Conn, error) {
|
dial = func(addr string) (net.Conn, error) {
|
||||||
return net.Dial(net_, addr)
|
return net.Dial("tcp", addr)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
target, err := dial(ctx, "tcp", req.Address())
|
target, err := dial(req.ConnectAddress())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
msg := err.Error()
|
msg := err.Error()
|
||||||
resp := hostUnreachable
|
resp := hostUnreachable
|
||||||
|
@ -193,7 +185,7 @@ func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) err
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleBind is used to handle a connect command
|
// handleBind is used to handle a connect command
|
||||||
func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error {
|
func (s *Server) handleBind(ctx context.Context, conn socksConn, req *Request) error {
|
||||||
// Check if this is allowed
|
// Check if this is allowed
|
||||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
||||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
||||||
|
@ -211,34 +203,14 @@ func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleAssociate is used to handle a connect command
|
|
||||||
func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error {
|
|
||||||
// Check if this is allowed
|
|
||||||
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
|
|
||||||
if err := sendReply(conn, ruleFailure, nil); err != nil {
|
|
||||||
return fmt.Errorf("Failed to send reply: %v", err)
|
|
||||||
}
|
|
||||||
return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
|
|
||||||
} else {
|
|
||||||
ctx = ctx_
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Support associate
|
|
||||||
if err := sendReply(conn, commandNotSupported, nil); err != nil {
|
|
||||||
return fmt.Errorf("Failed to send reply: %v", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// readAddrSpec is used to read AddrSpec.
|
// readAddrSpec is used to read AddrSpec.
|
||||||
// Expects an address type byte, follwed by the address and port
|
// Expects an address type byte, follwed by the address and port
|
||||||
func readAddrSpec(r io.Reader) (*AddrSpec, error) {
|
func readAddrSpec(r io.Reader, d *AddrSpec) error {
|
||||||
d := &AddrSpec{}
|
|
||||||
|
|
||||||
// Get the address type
|
// Get the address type
|
||||||
addrType := []byte{0}
|
addrType := []byte{0}
|
||||||
if _, err := r.Read(addrType); err != nil {
|
if _, err := r.Read(addrType); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle on a per type basis
|
// Handle on a per type basis
|
||||||
|
@ -246,40 +218,40 @@ func readAddrSpec(r io.Reader) (*AddrSpec, error) {
|
||||||
case ipv4Address:
|
case ipv4Address:
|
||||||
addr := make([]byte, 4)
|
addr := make([]byte, 4)
|
||||||
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
|
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
d.IP = net.IP(addr)
|
d.IP = net.IP(addr)
|
||||||
|
|
||||||
case ipv6Address:
|
case ipv6Address:
|
||||||
addr := make([]byte, 16)
|
addr := make([]byte, 16)
|
||||||
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
|
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
d.IP = net.IP(addr)
|
d.IP = net.IP(addr)
|
||||||
|
|
||||||
case fqdnAddress:
|
case fqdnAddress:
|
||||||
if _, err := r.Read(addrType); err != nil {
|
if _, err := r.Read(addrType); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
addrLen := int(addrType[0])
|
addrLen := int(addrType[0])
|
||||||
fqdn := make([]byte, addrLen)
|
fqdn := make([]byte, addrLen)
|
||||||
if _, err := io.ReadAtLeast(r, fqdn, addrLen); err != nil {
|
if _, err := io.ReadAtLeast(r, fqdn, addrLen); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
d.FQDN = string(fqdn)
|
d.FQDN = string(fqdn)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return nil, unrecognizedAddrType
|
return unrecognizedAddrType
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read the port
|
// Read the port
|
||||||
port := []byte{0, 0}
|
port := []byte{0, 0}
|
||||||
if _, err := io.ReadAtLeast(r, port, 2); err != nil {
|
if _, err := io.ReadAtLeast(r, port, 2); err != nil {
|
||||||
return nil, err
|
return err
|
||||||
}
|
}
|
||||||
d.Port = (int(port[0]) << 8) | int(port[1])
|
d.Port = (int(port[0]) << 8) | int(port[1])
|
||||||
|
|
||||||
return d, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendReply is used to send a reply message
|
// sendReply is used to send a reply message
|
||||||
|
@ -331,9 +303,9 @@ func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error {
|
||||||
|
|
||||||
// proxy is used to suffle data from src to destination, and sends errors
|
// proxy is used to suffle data from src to destination, and sends errors
|
||||||
// down a dedicated channel
|
// down a dedicated channel
|
||||||
func proxy(dst conn, src io.Reader, errCh chan error) {
|
func proxy(dst socksConn, src io.Reader, errCh chan error) {
|
||||||
var buf [1024 * 4]byte
|
var buf [1024 * 4]byte
|
||||||
_, err := io.CopyBuffer(dst, src, buf[:])
|
_, err := io.CopyBuffer(dst, src, buf[:])
|
||||||
dst.CloseWrite()
|
dst.Close()
|
||||||
errCh <- err
|
errCh <- err
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,8 +6,6 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
"golang.org/x/net/context"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -129,8 +127,8 @@ func (s *Server) ServeConn(conn net.Conn) error {
|
||||||
s.config.Logger.Printf("[ERR] socks: %v", err)
|
s.config.Logger.Printf("[ERR] socks: %v", err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
var req Request
|
||||||
request, err := NewRequest(bufConn)
|
err = readRequest(bufConn, &req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == unrecognizedAddrType {
|
if err == unrecognizedAddrType {
|
||||||
if err := sendReply(conn, addrTypeNotSupported, nil); err != nil {
|
if err := sendReply(conn, addrTypeNotSupported, nil); err != nil {
|
||||||
|
@ -139,13 +137,13 @@ func (s *Server) ServeConn(conn net.Conn) error {
|
||||||
}
|
}
|
||||||
return fmt.Errorf("Failed to read destination address: %v", err)
|
return fmt.Errorf("Failed to read destination address: %v", err)
|
||||||
}
|
}
|
||||||
request.AuthContext = authContext
|
req.AuthContext = authContext
|
||||||
if client, ok := conn.RemoteAddr().(*net.TCPAddr); ok {
|
if client, ok := conn.RemoteAddr().(*net.TCPAddr); ok {
|
||||||
request.RemoteAddr = &AddrSpec{IP: client.IP, Port: client.Port}
|
req.RemoteAddr = &AddrSpec{IP: client.IP, Port: client.Port}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process the client request
|
// Process the client request
|
||||||
if err := s.handleRequest(request, conn); err != nil {
|
if err := s.handleRequest(&req, conn); err != nil {
|
||||||
err = fmt.Errorf("Failed to handle request: %v", err)
|
err = fmt.Errorf("Failed to handle request: %v", err)
|
||||||
s.config.Logger.Printf("[ERR] socks: %v", err)
|
s.config.Logger.Printf("[ERR] socks: %v", err)
|
||||||
return err
|
return err
|
||||||
|
|
Loading…
Reference in New Issue