@@ -28,45 +28,62 @@ type Rule struct {
2828
2929func main () {
3030 var (
31- file string
31+ file string
32+ verbose bool
33+ dialTimeout time.Duration
3234 )
3335
3436 flag .StringVar (& file , "f" , "" , "Job to run or leave blank for job.yaml in current directory" )
35-
37+ flag .BoolVar (& verbose , "v" , true , "Verbose output for opened and closed connections" )
38+ flag .DurationVar (& dialTimeout , "t" , time .Millisecond * 1500 , "Dial timeout" )
3639 flag .Parse ()
3740
41+ if len (file ) == 0 {
42+ fmt .Fprintf (os .Stderr , "usage: mixctl -f rules.yaml\n " )
43+ os .Exit (1 )
44+ }
45+
3846 set := ForwardingSet {}
3947 data , err := os .ReadFile (file )
4048 if err != nil {
41- log .Fatalf ("error reading file %s %s" , file , err .Error ())
49+ fmt .Fprintf (os .Stderr , "error reading file %s %s" , file , err .Error ())
50+ os .Exit (1 )
4251 }
4352 if err = yaml .Unmarshal (data , & set ); err != nil {
44- log .Fatalf ("error parsing file %s %s" , file , err .Error ())
53+ fmt .Fprintf (os .Stderr , "error parsing file %s %s" , file , err .Error ())
54+ os .Exit (1 )
55+ }
56+
57+ if len (set .Rules ) == 0 {
58+ fmt .Fprintf (os .Stderr , "no rules found in file %s" , file )
59+ os .Exit (1 )
4560 }
4661
47- fmt .Printf ("mixctl by inlets.. \n " )
62+ fmt .Printf ("Starting mixctl by https:// inlets.dev/ \n \n " )
4863
4964 wg := sync.WaitGroup {}
5065 wg .Add (len (set .Rules ))
51- for _ , f := range set .Rules {
52-
53- r := f
54- go func (rule * Rule ) {
55- fmt .Printf ("Forward (%s) from: %s to: %s\n " , rule .Name , rule .From , rule .To )
66+ for _ , rule := range set .Rules {
67+ fmt .Printf ("Forward (%s) from: %s to: %s\n " , rule .Name , rule .From , rule .To )
68+ }
69+ fmt .Println ()
5670
57- if err := forward (rule .Name , rule .From , rule .To ); err != nil {
71+ for _ , rule := range set .Rules {
72+ // Copy the value to avoid the loop variable being reused
73+ r := rule
74+ go func () {
75+ if err := forward (r .Name , r .From , r .To , verbose , dialTimeout ); err != nil {
5876 log .Printf ("error forwarding %s" , err .Error ())
5977 os .Exit (1 )
6078 }
61-
6279 defer wg .Done ()
63- }(& r )
80+ }()
6481 }
65- wg .Wait ()
6682
83+ wg .Wait ()
6784}
6885
69- func forward (name , from string , to []string ) error {
86+ func forward (name , from string , to []string , verbose bool , dialTimeout time. Duration ) error {
7087 seed := time .Now ().UnixNano ()
7188 rand .Seed (seed )
7289
@@ -76,42 +93,92 @@ func forward(name, from string, to []string) error {
7693 return fmt .Errorf ("error listening on %s %s" , from , err .Error ())
7794 }
7895
96+ defer l .Close ()
97+
7998 for {
80- conn , err := l .Accept ()
99+ // accept a connection on the local port of the load balancer
100+ local , err := l .Accept ()
81101 if err != nil {
82102 return fmt .Errorf ("error accepting connection %s" , err .Error ())
83103 }
84104
105+ // pick randomly from the list of upstream servers
106+ // available
85107 index := rand .Intn (len (to ))
108+ upstream := to [index ]
86109
87- remote , err := net .Dial ("tcp" , to [index ])
88- if err != nil {
89- return fmt .Errorf ("error dialing %s %s" , to [index ], err .Error ())
90- }
110+ // A separate Goroutine means the loop can accept another
111+ // incoming connection on the local address
112+ go connect (local , upstream , from , verbose , dialTimeout )
113+ }
114+ }
91115
92- go func () {
93- log .Printf ("[%s] %s => %s" ,
94- from ,
95- conn .RemoteAddr ().String (),
96- remote .RemoteAddr ().String ())
97- if err := forwardConnection (conn , remote ); err != nil && err .Error () != "done" {
98- log .Printf ("error forwarding connection %s" , err .Error ())
99- }
100- }()
116+ // connect dials the upstream address, then copies data
117+ // between it and connection accepted on a local port
118+ func connect (local net.Conn , upstreamAddr , from string , verbose bool , dialTimeout time.Duration ) {
119+ defer local .Close ()
120+
121+ // If Dial is used on its own, then the timeout can be as long
122+ // as 2 minutes on MacOS for an unreachable host
123+ upstream , err := net .DialTimeout ("tcp" , upstreamAddr , dialTimeout )
124+ if err != nil {
125+ log .Printf ("error dialing %s %s" , upstreamAddr , err .Error ())
126+ return
127+ }
128+ defer upstream .Close ()
129+
130+ if verbose {
131+ log .Printf ("Connected %s => %s (%s)" ,
132+ from ,
133+ upstream .RemoteAddr ().String (),
134+ local .RemoteAddr ().String ())
135+ }
136+
137+ ctx := context .Background ()
138+ if err := copy (ctx , local , upstream ); err != nil && err .Error () != "done" {
139+ log .Printf ("error forwarding connection %s" , err .Error ())
140+ }
141+
142+ if verbose {
143+ log .Printf ("Closed %s => %s (%s)" ,
144+ from ,
145+ upstream .RemoteAddr ().String (),
146+ local .RemoteAddr ().String ())
101147 }
102148}
103149
104- func forwardConnection (from , to net.Conn ) error {
105- errgrp , _ := errgroup .WithContext (context .Background ())
150+ // copy copies data between two connections using io.Copy
151+ // and will exit when either connection is closed or runs
152+ // into an error
153+ func copy (ctx context.Context , from , to net.Conn ) error {
154+
155+ ctx , cancel := context .WithCancel (ctx )
156+ errgrp , _ := errgroup .WithContext (ctx )
106157 errgrp .Go (func () error {
107158 io .Copy (from , to )
159+ cancel ()
108160
109161 return fmt .Errorf ("done" )
110162 })
111163 errgrp .Go (func () error {
112164 io .Copy (to , from )
165+ cancel ()
166+
113167 return fmt .Errorf ("done" )
114168 })
169+ errgrp .Go (func () error {
170+ <- ctx .Done ()
171+
172+ // This closes both ends of the connection as
173+ // soon as possible.
174+ from .Close ()
175+ to .Close ()
176+ return fmt .Errorf ("done" )
177+ })
178+
179+ if err := errgrp .Wait (); err != nil {
180+ return err
181+ }
115182
116- return errgrp . Wait ()
183+ return nil
117184}
0 commit comments