diff options
Diffstat (limited to 'internal')
| -rw-r--r-- | internal/auth/auth.go | 68 | ||||
| -rw-r--r-- | internal/auth/login/index.html | 123 | ||||
| -rw-r--r-- | internal/cmd/root.go | 111 | ||||
| -rw-r--r-- | internal/middleware/middleware.go | 38 | ||||
| -rw-r--r-- | internal/proxies/proxy.go | 1 |
5 files changed, 306 insertions, 35 deletions
diff --git a/internal/auth/auth.go b/internal/auth/auth.go index 2826a47..b58af82 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -5,6 +5,7 @@ import ( "net/http" "os" "strings" + "time" "github.com/golang-jwt/jwt/v5" "github.com/joho/godotenv" @@ -14,43 +15,72 @@ func init() { godotenv.Load() } -var ( - jwtsecret = []byte(os.Getenv("JWT_SECRET")) - algo = string(os.Getenv("JWT_ALGO")) -) +var jwtSecret = []byte(os.Getenv("JWT_SECRET")) + +const algo = "HS256" + +func GenerateJWT(userID string, duration time.Duration) (string, error) { + claims := jwt.MapClaims{ + "sub": userID, + "exp": time.Now().Add(duration).Unix(), + "iat": time.Now().Unix(), + } + + token := jwt.NewWithClaims(jwt.GetSigningMethod(algo), claims) + signedToken, err := token.SignedString(jwtSecret) + if err != nil { + return "", fmt.Errorf("failed to sign token: %w", err) + } + + return signedToken, nil +} -func validateJWT(tokenString string, expectedAlg string) (bool, error) { - token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) { +func validateJWT(tokenString string, expectedAlg string) (jwt.MapClaims, error) { + token, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) { if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { return nil, fmt.Errorf("invalid token: %v", token.Header["alg"]) } - if token.Header["alg"] != expectedAlg { - return nil, fmt.Errorf("incorrect alg: %v", token.Header["alg"]) + if token.Method != jwt.GetSigningMethod(expectedAlg) { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) } - return jwtsecret, nil + return jwtSecret, nil }) - if err != nil || !token.Valid { - return false, err + + if err != nil { + return nil, err + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok || !token.Valid { + return nil, fmt.Errorf("JWT validation failed: invalid claims") } - if _, ok := token.Claims.(jwt.MapClaims); ok && token.Valid { - return true, nil + if exp, ok := claims["exp"].(float64); ok { + if int64(exp) < time.Now().Unix() { + return nil, fmt.Errorf("JWT validation failed: token expired") + } } - return false, fmt.Errorf("invalid token") + + return claims, nil } -func Verifyrequest(r *http.Request) bool { +func VerifyRequest(r *http.Request) (jwt.MapClaims, error) { authHeader := r.Header.Get("Authorization") if authHeader == "" { - return false + return nil, fmt.Errorf("Missing Authorization header in request") } // expected: "Bearer <token>" authSplit := strings.Split(authHeader, " ") if len(authSplit) != 2 || authSplit[0] != "Bearer" { - return false + return nil, fmt.Errorf("Malformed Authorization header in request") } token := authSplit[1] - isValid, _ := validateJWT(token, algo) - return isValid + claims, err := validateJWT(token, algo) + if err != nil { + fmt.Printf("JWT error: %v for token: %s\n", err, token) + return nil, err + } + + return claims, nil } diff --git a/internal/auth/login/index.html b/internal/auth/login/index.html new file mode 100644 index 0000000..a42aef7 --- /dev/null +++ b/internal/auth/login/index.html @@ -0,0 +1,123 @@ +<!doctype html> +<html lang="en"> + <head> + <meta charset="UTF-8" /> + <meta name="viewport" content="width=device-width, initial-scale=1.0" /> + <title>Login</title> + <link + href="https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap" + rel="stylesheet" + /> + <style> + body { + margin: 0; + font-family: "Inter", sans-serif; + background: linear-gradient(135deg, #667eea, #764ba2); + display: flex; + justify-content: center; + align-items: center; + height: 100vh; + } + + .login-card { + background: #fff; + padding: 2rem; + border-radius: 16px; + box-shadow: 0 12px 30px rgba(0, 0, 0, 0.2); + max-width: 400px; + width: 100%; + box-sizing: border-box; + animation: fadeIn 0.5s ease-in-out; + } + + .login-card h2 { + margin-top: 0; + margin-bottom: 1.5rem; + font-size: 1.75rem; + color: #333; + text-align: center; + } + + .form-group { + margin-bottom: 1rem; + } + + label { + display: block; + margin-bottom: 0.5rem; + font-weight: 600; + color: #555; + } + + input[type="text"], + input[type="password"] { + width: 100%; + padding: 0.75rem; + border: 1px solid #ccc; + border-radius: 8px; + font-size: 1rem; + transition: border-color 0.2s; + } + + input[type="text"]:focus, + input[type="password"]:focus { + border-color: #667eea; + outline: none; + } + + button { + width: 100%; + padding: 0.75rem; + background-color: #667eea; + border: none; + border-radius: 8px; + color: white; + font-size: 1rem; + font-weight: 600; + cursor: pointer; + transition: background-color 0.3s ease; + } + + button:hover { + background-color: #5a67d8; + } + + .footer { + margin-top: 1rem; + text-align: center; + font-size: 0.9rem; + color: #888; + } + + @keyframes fadeIn { + from { + opacity: 0; + transform: scale(0.95); + } + to { + opacity: 1; + transform: scale(1); + } + } + </style> + </head> + <body> + <div class="login-card"> + <h2>Login</h2> + <form method="POST" action="/login"> + <div class="form-group"> + <label for="username">Username</label> + <input type="text" id="username" name="username" required /> + </div> + + <div class="form-group"> + <label for="password">Password</label> + <input type="password" id="password" name="password" required /> + </div> + + <button type="submit">Sign In</button> + </form> + <div class="footer">© Wacky404 Reverse Proxy Server</div> + </div> + </body> +</html> diff --git a/internal/cmd/root.go b/internal/cmd/root.go index 6d9c271..1e24015 100644 --- a/internal/cmd/root.go +++ b/internal/cmd/root.go @@ -2,43 +2,103 @@ package cmd import ( "context" + "fmt" "log" "net/http" "net/http/httputil" "net/url" + "sync" "time" "github.com/Wacky404/rpserver/internal/auth" + "github.com/Wacky404/rpserver/internal/middleware" + "github.com/golang-jwt/jwt/v5" ) -func ExecuteServer(portNum string) error { - // don't know if want to keep handlers in here - http.HandleFunc("/proxy", handleProxy) +var ( + proxyCache = make(map[string]*httputil.ReverseProxy) + cacheMutex sync.RWMutex +) + +func ExecuteServer(port string, cert string, key string) error { + mux := http.NewServeMux() - log.Printf("Reverse Proxy running on :%v", portNum) - err := http.ListenAndServe(":"+portNum, nil) + mux.Handle("/", middleware.Recover(http.HandlerFunc(serveLoginPage))) + mux.Handle("/auth/login", middleware.Recover(http.HandlerFunc(handleLogin))) + mux.Handle("/auth/signup", middleware.Recover(http.HandlerFunc(serveSignUpPage))) + mux.Handle("/proxy", middleware.Recover(middleware.JWT(http.HandlerFunc(handleProxy)))) + mux.Handle("/status", middleware.Recover(http.HandlerFunc(handleStatus))) + err := http.ListenAndServeTLS(port, cert, key, mux) return err } +func serveSignUpPage(w http.ResponseWriter, r *http.Request) {} + +func handleSignUp(w http.ResponseWriter, r *http.Request) {} + +func serveLoginPage(w http.ResponseWriter, r *http.Request) { + log.Println("Trying to server the login page!!!") + http.ServeFile(w, r, "internal/auth/login/index.html") +} + +/* This function doesn't have proper auth for login creds */ +func handleLogin(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + http.Redirect(w, r, "/", http.StatusSeeOther) + return + } + + username := r.FormValue("username") + password := r.FormValue("password") + + // pull this out into auth function + if username == "admin" && password == "password4321" { + token, err := auth.GenerateJWT(username, time.Hour) + if err != nil { + log.Printf("JWT generation error: %v", err) + http.Error(w, "Could not generate token:", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "application/json") + fmt.Fprintf(w, `{"token": "%s"}`, token) + + return + } + + http.Redirect(w, r, "/", http.StatusSeeOther) +} + +func handleStatus(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprintln(w, "200 OK") +} + func handleProxy(w http.ResponseWriter, r *http.Request) { - if !auth.Verifyrequest(r) { - http.Error(w, "Unauthorized", http.StatusUnauthorized) + claims, ok := r.Context().Value("claims").(jwt.MapClaims) + if !ok { + fmt.Println("Is this failing...") + http.Error(w, "Failed to get JWT claims", http.StatusInternalServerError) return } + userID := claims["sub"] + role := claims["role"] + fmt.Printf("%v, %v", userID, role) backendURL, err := getBackendURL(r) if err != nil { + fmt.Println("Is this failing...2") http.Error(w, "Backend URL not provided", http.StatusBadRequest) return } - proxy := httputil.NewSingleHostReverseProxy(backendURL) + proxy := getOrCreateProxy(backendURL) ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second) defer cancel() - r = r.WithContext(ctx) // attaching the new ctx to request + r = r.WithContext(ctx) proxy.Director = func(req *http.Request) { req.URL.Scheme = backendURL.Scheme @@ -49,24 +109,43 @@ func handleProxy(w http.ResponseWriter, r *http.Request) { done := make(chan struct{}) go func() { - proxy.ServeHTTP(w, r) // Forward the request + proxy.ServeHTTP(w, r) close(done) }() select { - case <-ctx.Done(): // if context timout occurs + case <-ctx.Done(): http.Error(w, "Request timed out", http.StatusGatewayTimeout) log.Println("Request to", r.URL.Path, "timed out...balls") - case <-done: // if request completes successfully + case <-done: } } func getBackendURL(r *http.Request) (*url.URL, error) { - backend := r.Header.Get("X-Backend-URL") // extraction of backend from request - + backend := r.Header.Get("X-Backend-URL") if backend == "" { - return nil, http.ErrNoLocation // error if no backend is in request + return nil, http.ErrNoLocation } + return url.Parse(backend) +} + +func getOrCreateProxy(target *url.URL) *httputil.ReverseProxy { + cacheMutex.RLock() + proxy, exists := proxyCache[target.String()] + cacheMutex.RUnlock() + if exists { + return proxy + } + + cacheMutex.Lock() + defer cacheMutex.Unlock() + + if proxy, exists = proxyCache[target.String()]; exists { + return proxy + } + + proxy = httputil.NewSingleHostReverseProxy(target) + proxyCache[target.String()] = proxy - return url.Parse(backend) // parse and return the backend url + return proxy } diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go new file mode 100644 index 0000000..849e705 --- /dev/null +++ b/internal/middleware/middleware.go @@ -0,0 +1,38 @@ +package middleware + +import ( + "context" + "log" + "net/http" + + "github.com/Wacky404/rpserver/internal/auth" +) + +type key string + +const claimsKey key = "claims" + +func JWT(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + claims, err := auth.VerifyRequest(r) + if err != nil { + http.Error(w, "Unauthorized", http.StatusUnauthorized) + return + } + + ctx := context.WithValue(r.Context(), claimsKey, claims) + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + +func Recover(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + defer func() { + if err := recover(); err != nil { + log.Printf("Recovered from panic: %v", err) + http.Error(w, "Internal Server Error", http.StatusInternalServerError) + } + }() + next.ServeHTTP(w, r) + }) +} diff --git a/internal/proxies/proxy.go b/internal/proxies/proxy.go new file mode 100644 index 0000000..769a888 --- /dev/null +++ b/internal/proxies/proxy.go @@ -0,0 +1 @@ +package proxies |
