diff --git a/go.mod b/go.mod
index 5431ee7..db370e4 100644
--- a/go.mod
+++ b/go.mod
@@ -29,5 +29,6 @@ require (
 	golang.org/x/exp v0.0.0-20220303212507-bbda1eaf7a17 // indirect
 	golang.org/x/net v0.7.0 // indirect
 	golang.org/x/sys v0.12.0 // indirect
+	golang.org/x/time v0.5.0 // indirect
 	google.golang.org/protobuf v1.28.1 // indirect
 )
diff --git a/go.sum b/go.sum
index af3a18f..45ac4e4 100644
--- a/go.sum
+++ b/go.sum
@@ -137,6 +137,8 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
 golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
+golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk=
+golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM=
 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
 golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
diff --git a/helpers.go b/helpers.go
index 0bd76df..339435c 100644
--- a/helpers.go
+++ b/helpers.go
@@ -1,17 +1,106 @@
 package main
 
 import (
+	"context"
 	"io"
 	"net/http"
+	"regexp"
 	"time"
 
 	"github.com/rs/zerolog/log"
+	"golang.org/x/time/rate"
 )
 
+// DomainLimiters maps a simplified domain (see SimplifyUrlToDomain) to the
+// rate limiter shared by all requests to that domain. GetLimiter adds an entry
+// the first time a domain is seen; the map itself is not guarded by a lock.
+var DomainLimiters = map[string]*rate.Limiter{
+	"utsa.edu": rate.NewLimiter(2, 5),
+}
+
+// GetLimiter returns the rate limiter for the given host, creating and
+// registering a default one (1 request per second, burst of 3) on first use.
+func GetLimiter(domain string) *rate.Limiter {
+	// Naively simplify the domain
+	simplifiedDomain := SimplifyUrlToDomain(domain)
+	if simplifiedDomain == "" {
+		// Hosts the pattern cannot simplify (e.g. "localhost") keep their own
+		// limiter instead of all sharing the empty-string key.
+		simplifiedDomain = domain
+	}
+	if simplifiedDomain != domain {
+		log.Debug().Str("domain", domain).Str("simplified", simplifiedDomain).Msg("Domain Simplified")
+	}
+
+	// Get the limiter
+	limiter, ok := DomainLimiters[simplifiedDomain]
+
+	// Create a new limiter if one does not exist
+	if !ok {
+		limiter = rate.NewLimiter(1, 3)
+		DomainLimiters[simplifiedDomain] = limiter
+		log.Debug().Str("domain", domain).Msg("New Limiter Created")
+	}
+	return limiter
+}
+
+// DomainPattern captures the last two dot-separated labels of a host name,
+// e.g. "utsa.edu" from "www.utsa.edu". Labels may contain hyphens. It is
+// deliberately naive and knows nothing about suffixes such as "co.uk".
+var DomainPattern = regexp.MustCompile(`(?:[\w-]+\.)*([\w-]+\.[\w-]+)(?:\/)?`)
+
+// SimplifyUrlToDomain reduces a host name to the domain captured by
+// DomainPattern. It returns "" when the input contains no dotted name.
+func SimplifyUrlToDomain(url string) string {
+	// Find the domain
+	matches := DomainPattern.FindStringSubmatch(url)
+	if len(matches) == 0 {
+		return ""
+	}
+	return matches[1]
+}
+
+// Wait blocks until limiter grants a token or ctx is done, whichever happens
+// first. If the reservation cannot be granted it logs a warning and returns
+// immediately. A nil ctx is treated as context.Background().
+func Wait(limiter *rate.Limiter, ctx context.Context) {
+	if ctx == nil {
+		ctx = context.Background()
+	}
+
+	r := limiter.Reserve()
+	if !r.OK() {
+		log.Warn().Msg("Rate Limit Exceeded")
+		return
+	}
+
+	// Read the delay once; Delay is relative to the current time, so repeated
+	// calls would log one value and sleep for another.
+	delay := r.Delay()
+	if delay <= 0 {
+		return
+	}
+	log.Debug().Str("delay", delay.String()).Msg("Waiting")
+
+	// Wait for the limiter, but give up as soon as the context is cancelled
+	timer := time.NewTimer(delay)
+	defer timer.Stop()
+	select {
+	case <-timer.C:
+	case <-ctx.Done():
+		// Release the reservation; a cancelled request is not expected to use it
+		r.Cancel()
+	}
+}
+
 // DoRequestNoRead makes a request and returns the response
 // Compared to DoRequest, this function does not read the response body, and it uses the Content-Length header for the associated log attribute.
 // This function encapsulates the boilerplate for logging.
 func DoRequestNoRead(req *http.Request) (*http.Response, error) {
+	// Acquire the limiter, and wait for a token
+	limiter := GetLimiter(req.URL.Host)
+	Wait(limiter, req.Context())
+
 	// Log the request
 	log.Debug().Str("method", req.Method).Str("host", req.Host).Str("path", req.URL.Path).Msg("Request")
 
@@ -34,6 +123,10 @@ func DoRequestNoRead(req *http.Request) (*http.Response, error) {
 // DoRequest makes a request and returns the response and body
 // This function encapsulates the boilerplate for logging and reading the response body
 func DoRequest(req *http.Request) (*http.Response, []byte, error) {
+	// Acquire the limiter, and wait for a token
+	limiter := GetLimiter(req.URL.Host)
+	Wait(limiter, req.Context())
+
 	// Log the request
 	log.Debug().Str("method", req.Method).Str("host", req.Host).Str("path", req.URL.Path).Msg("Request")
 