This commit is contained in:
Jeff Becker 2018-08-29 08:25:27 -04:00
parent 7c6dca6e40
commit a5ae45e0b9
No known key found for this signature in database
GPG Key ID: F357B3B42F6F9B05
3 changed files with 55 additions and 71 deletions

30
main.go
View File

@ -2,7 +2,6 @@ package main
import ( import (
"socks5" "socks5"
"golang.org/x/net/context"
"golang.org/x/net/proxy" "golang.org/x/net/proxy"
"net" "net"
"os" "os"
@ -14,38 +13,53 @@ func main() {
args := os.Args[1:] args := os.Args[1:]
if len(args) < 2 { if len(args) < 2 {
fmt.Printf("usage: %s bindaddr upstreamaddr\n", os.Args[0]) fmt.Printf("usage: %s bindaddr onionsocksaddr [i2psocksaddr]\n", os.Args[0])
return return
} }
var onion, i2p proxy.Dialer
upstream, err:= proxy.SOCKS5("tcp", os.Args[2], nil, nil) var err error
onion, err = proxy.SOCKS5("tcp", os.Args[2], nil, nil)
if err != nil { if err != nil {
fmt.Printf("failed to create upstream proxy to %s, %s", os.Args[2], err.Error()) fmt.Printf("failed to create onion proxy to %s, %s\n", os.Args[2], err.Error())
return return
} }
if len(args) > 3 {
i2p, err = proxy.SOCKS5("tcp", os.Args[3], nil, nil)
if err != nil {
fmt.Printf("failed to create i2p proxy to %s, %s\n", os.Args[3], err.Error())
return
}
}
serv, err := socks5.New(&socks5.Config{ serv, err := socks5.New(&socks5.Config{
Dial: func(addr string) (net.Conn, error) { Dial: func(addr string) (net.Conn, error) {
host, _, err := net.SplitHostPort(addr) host, _, err := net.SplitHostPort(addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if strings.HasSuffix(host, ".i2p") {
if i2p == nil {
return onion.Dial("tcp", addr)
}
return i2p.Dial("tcp", addr)
}
if strings.HasSuffix(host, ".onion") { if strings.HasSuffix(host, ".onion") {
return upstream.Dial("tcp", addr) return onion.Dial("tcp", addr)
} }
return net.Dial("tcp", addr) return net.Dial("tcp", addr)
}, },
}) })
if err != nil { if err != nil {
fmt.Printf("failed to create socks proxy %s", err.Error()) fmt.Printf("failed to create socks proxy %s\n", err.Error())
return return
} }
l, err := net.Listen("tcp", os.Args[1]) l, err := net.Listen("tcp", os.Args[1])
if err != nil { if err != nil {
fmt.Printf("failed to listen on %s, %s", os.Args[1], err.Error()) fmt.Printf("failed to listen on %s, %s\n", os.Args[1], err.Error())
return return
} }
fmt.Printf("proxy serving on %s\n", os.Args[1])
serv.Serve(l) serv.Serve(l)
} }

View File

@ -79,43 +79,37 @@ func (req *Request) ConnectAddress() string {
return req.DestAddr.Address() return req.DestAddr.Address()
} }
type conn interface { type socksConn interface {
io.WriteCloser io.WriteCloser
RemoteAddr() net.Addr RemoteAddr() net.Addr
CloseWrite() error
} }
// NewRequest creates a new Request from the tcp connection // readRequest creates a new Request from the tcp connection
func NewRequest(bufConn io.Reader) (*Request, error) { func readRequest(bufConn io.Reader, req *Request) error {
// Read the version byte // Read the version byte
header := []byte{0, 0, 0} header := []byte{0, 0, 0}
if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil { if _, err := io.ReadAtLeast(bufConn, header, 3); err != nil {
return nil, fmt.Errorf("Failed to get command version: %v", err) return fmt.Errorf("Failed to get command version: %v", err)
} }
// Ensure we are compatible // Ensure we are compatible
if header[0] != socks5Version { if header[0] != socks5Version {
return nil, fmt.Errorf("Unsupported command version: %v", header[0]) return fmt.Errorf("Unsupported command version: %v", header[0])
} }
// Read in the destination address // Read in the destination address
dest, err := readAddrSpec(bufConn) err := readAddrSpec(bufConn, &req.DestAddr)
if err != nil { if err != nil {
return nil, err return err
} }
req.Version = socks5Version
request := &Request{ req.Command = header[1]
Version: socks5Version, req.bufConn = bufConn
Command: header[1], return nil
DestAddr: dest,
bufConn: bufConn,
}
return request, nil
} }
// handleRequest is used for request processing after authentication // handleRequest is used for request processing after authentication
func (s *Server) handleRequest(req *Request, conn conn) error { func (s *Server) handleRequest(req *Request, conn socksConn) error {
ctx := context.Background() ctx := context.Background()
// Switch on the command // Switch on the command
@ -124,8 +118,6 @@ func (s *Server) handleRequest(req *Request, conn conn) error {
return s.handleConnect(ctx, conn, req) return s.handleConnect(ctx, conn, req)
case BindCommand: case BindCommand:
return s.handleBind(ctx, conn, req) return s.handleBind(ctx, conn, req)
case AssociateCommand:
return s.handleAssociate(ctx, conn, req)
default: default:
if err := sendReply(conn, commandNotSupported, nil); err != nil { if err := sendReply(conn, commandNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err) return fmt.Errorf("Failed to send reply: %v", err)
@ -135,7 +127,7 @@ func (s *Server) handleRequest(req *Request, conn conn) error {
} }
// handleConnect is used to handle a connect command // handleConnect is used to handle a connect command
func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) error { func (s *Server) handleConnect(ctx context.Context, conn socksConn, req *Request) error {
// Check if this is allowed // Check if this is allowed
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil { if err := sendReply(conn, ruleFailure, nil); err != nil {
@ -149,11 +141,11 @@ func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) err
// Attempt to connect // Attempt to connect
dial := s.config.Dial dial := s.config.Dial
if dial == nil { if dial == nil {
dial = func(ctx context.Context, net_, addr string) (net.Conn, error) { dial = func(addr string) (net.Conn, error) {
return net.Dial(net_, addr) return net.Dial("tcp", addr)
} }
} }
target, err := dial(ctx, "tcp", req.Address()) target, err := dial(req.ConnectAddress())
if err != nil { if err != nil {
msg := err.Error() msg := err.Error()
resp := hostUnreachable resp := hostUnreachable
@ -193,7 +185,7 @@ func (s *Server) handleConnect(ctx context.Context, conn conn, req *Request) err
} }
// handleBind is used to handle a connect command // handleBind is used to handle a connect command
func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error { func (s *Server) handleBind(ctx context.Context, conn socksConn, req *Request) error {
// Check if this is allowed // Check if this is allowed
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok { if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil { if err := sendReply(conn, ruleFailure, nil); err != nil {
@ -211,34 +203,14 @@ func (s *Server) handleBind(ctx context.Context, conn conn, req *Request) error
return nil return nil
} }
// handleAssociate is used to handle a connect command
func (s *Server) handleAssociate(ctx context.Context, conn conn, req *Request) error {
// Check if this is allowed
if ctx_, ok := s.config.Rules.Allow(ctx, req); !ok {
if err := sendReply(conn, ruleFailure, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return fmt.Errorf("Associate to %v blocked by rules", req.DestAddr)
} else {
ctx = ctx_
}
// TODO: Support associate
if err := sendReply(conn, commandNotSupported, nil); err != nil {
return fmt.Errorf("Failed to send reply: %v", err)
}
return nil
}
// readAddrSpec is used to read AddrSpec. // readAddrSpec is used to read AddrSpec.
// Expects an address type byte, follwed by the address and port // Expects an address type byte, follwed by the address and port
func readAddrSpec(r io.Reader) (*AddrSpec, error) { func readAddrSpec(r io.Reader, d *AddrSpec) error {
d := &AddrSpec{}
// Get the address type // Get the address type
addrType := []byte{0} addrType := []byte{0}
if _, err := r.Read(addrType); err != nil { if _, err := r.Read(addrType); err != nil {
return nil, err return err
} }
// Handle on a per type basis // Handle on a per type basis
@ -246,40 +218,40 @@ func readAddrSpec(r io.Reader) (*AddrSpec, error) {
case ipv4Address: case ipv4Address:
addr := make([]byte, 4) addr := make([]byte, 4)
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
return nil, err return err
} }
d.IP = net.IP(addr) d.IP = net.IP(addr)
case ipv6Address: case ipv6Address:
addr := make([]byte, 16) addr := make([]byte, 16)
if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil { if _, err := io.ReadAtLeast(r, addr, len(addr)); err != nil {
return nil, err return err
} }
d.IP = net.IP(addr) d.IP = net.IP(addr)
case fqdnAddress: case fqdnAddress:
if _, err := r.Read(addrType); err != nil { if _, err := r.Read(addrType); err != nil {
return nil, err return err
} }
addrLen := int(addrType[0]) addrLen := int(addrType[0])
fqdn := make([]byte, addrLen) fqdn := make([]byte, addrLen)
if _, err := io.ReadAtLeast(r, fqdn, addrLen); err != nil { if _, err := io.ReadAtLeast(r, fqdn, addrLen); err != nil {
return nil, err return err
} }
d.FQDN = string(fqdn) d.FQDN = string(fqdn)
default: default:
return nil, unrecognizedAddrType return unrecognizedAddrType
} }
// Read the port // Read the port
port := []byte{0, 0} port := []byte{0, 0}
if _, err := io.ReadAtLeast(r, port, 2); err != nil { if _, err := io.ReadAtLeast(r, port, 2); err != nil {
return nil, err return err
} }
d.Port = (int(port[0]) << 8) | int(port[1]) d.Port = (int(port[0]) << 8) | int(port[1])
return d, nil return nil
} }
// sendReply is used to send a reply message // sendReply is used to send a reply message
@ -331,9 +303,9 @@ func sendReply(w io.Writer, resp uint8, addr *AddrSpec) error {
// proxy is used to suffle data from src to destination, and sends errors // proxy is used to suffle data from src to destination, and sends errors
// down a dedicated channel // down a dedicated channel
func proxy(dst conn, src io.Reader, errCh chan error) { func proxy(dst socksConn, src io.Reader, errCh chan error) {
var buf [1024 * 4]byte var buf [1024 * 4]byte
_, err := io.CopyBuffer(dst, src, buf[:]) _, err := io.CopyBuffer(dst, src, buf[:])
dst.CloseWrite() dst.Close()
errCh <- err errCh <- err
} }

View File

@ -6,8 +6,6 @@ import (
"log" "log"
"net" "net"
"os" "os"
"golang.org/x/net/context"
) )
const ( const (
@ -129,8 +127,8 @@ func (s *Server) ServeConn(conn net.Conn) error {
s.config.Logger.Printf("[ERR] socks: %v", err) s.config.Logger.Printf("[ERR] socks: %v", err)
return err return err
} }
var req Request
request, err := NewRequest(bufConn) err = readRequest(bufConn, &req)
if err != nil { if err != nil {
if err == unrecognizedAddrType { if err == unrecognizedAddrType {
if err := sendReply(conn, addrTypeNotSupported, nil); err != nil { if err := sendReply(conn, addrTypeNotSupported, nil); err != nil {
@ -139,13 +137,13 @@ func (s *Server) ServeConn(conn net.Conn) error {
} }
return fmt.Errorf("Failed to read destination address: %v", err) return fmt.Errorf("Failed to read destination address: %v", err)
} }
request.AuthContext = authContext req.AuthContext = authContext
if client, ok := conn.RemoteAddr().(*net.TCPAddr); ok { if client, ok := conn.RemoteAddr().(*net.TCPAddr); ok {
request.RemoteAddr = &AddrSpec{IP: client.IP, Port: client.Port} req.RemoteAddr = &AddrSpec{IP: client.IP, Port: client.Port}
} }
// Process the client request // Process the client request
if err := s.handleRequest(request, conn); err != nil { if err := s.handleRequest(&req, conn); err != nil {
err = fmt.Errorf("Failed to handle request: %v", err) err = fmt.Errorf("Failed to handle request: %v", err)
s.config.Logger.Printf("[ERR] socks: %v", err) s.config.Logger.Printf("[ERR] socks: %v", err)
return err return err