summaryrefslogtreecommitdiff
path: root/internal
diff options
context:
space:
mode:
Diffstat (limited to 'internal')
-rw-r--r--internal/auth/auth.go68
-rw-r--r--internal/auth/login/index.html123
-rw-r--r--internal/cmd/root.go111
-rw-r--r--internal/middleware/middleware.go38
-rw-r--r--internal/proxies/proxy.go1
5 files changed, 306 insertions, 35 deletions
diff --git a/internal/auth/auth.go b/internal/auth/auth.go
index 2826a47..b58af82 100644
--- a/internal/auth/auth.go
+++ b/internal/auth/auth.go
@@ -5,6 +5,7 @@ import (
"net/http"
"os"
"strings"
+ "time"
"github.com/golang-jwt/jwt/v5"
"github.com/joho/godotenv"
@@ -14,43 +15,72 @@ func init() {
godotenv.Load()
}
-var (
- jwtsecret = []byte(os.Getenv("JWT_SECRET"))
- algo = string(os.Getenv("JWT_ALGO"))
-)
+var jwtSecret = []byte(os.Getenv("JWT_SECRET"))
+
+const algo = "HS256"
+
+func GenerateJWT(userID string, duration time.Duration) (string, error) {
+ claims := jwt.MapClaims{
+ "sub": userID,
+ "exp": time.Now().Add(duration).Unix(),
+ "iat": time.Now().Unix(),
+ }
+
+ token := jwt.NewWithClaims(jwt.GetSigningMethod(algo), claims)
+ signedToken, err := token.SignedString(jwtSecret)
+ if err != nil {
+ return "", fmt.Errorf("failed to sign token: %w", err)
+ }
+
+ return signedToken, nil
+}
-func validateJWT(tokenString string, expectedAlg string) (bool, error) {
- token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
+func validateJWT(tokenString string, expectedAlg string) (jwt.MapClaims, error) {
+ token, err := jwt.Parse(tokenString, func(token *jwt.Token) (any, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("invalid token: %v", token.Header["alg"])
}
- if token.Header["alg"] != expectedAlg {
- return nil, fmt.Errorf("incorrect alg: %v", token.Header["alg"])
+ if token.Method != jwt.GetSigningMethod(expectedAlg) {
+ return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
- return jwtsecret, nil
+ return jwtSecret, nil
})
- if err != nil || !token.Valid {
- return false, err
+
+ if err != nil {
+ return nil, err
+ }
+
+ claims, ok := token.Claims.(jwt.MapClaims)
+ if !ok || !token.Valid {
+ return nil, fmt.Errorf("JWT validation failed: invalid claims")
}
- if _, ok := token.Claims.(jwt.MapClaims); ok && token.Valid {
- return true, nil
+ if exp, ok := claims["exp"].(float64); ok {
+ if int64(exp) < time.Now().Unix() {
+ return nil, fmt.Errorf("JWT validation failed: token expired")
+ }
}
- return false, fmt.Errorf("invalid token")
+
+ return claims, nil
}
-func Verifyrequest(r *http.Request) bool {
+func VerifyRequest(r *http.Request) (jwt.MapClaims, error) {
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
- return false
+ return nil, fmt.Errorf("Missing Authorization header in request")
}
// expected: "Bearer <token>"
authSplit := strings.Split(authHeader, " ")
if len(authSplit) != 2 || authSplit[0] != "Bearer" {
- return false
+ return nil, fmt.Errorf("Malformed Authorization header in request")
}
token := authSplit[1]
- isValid, _ := validateJWT(token, algo)
- return isValid
+ claims, err := validateJWT(token, algo)
+ if err != nil {
+ fmt.Printf("JWT error: %v for token: %s\n", err, token)
+ return nil, err
+ }
+
+ return claims, nil
}
diff --git a/internal/auth/login/index.html b/internal/auth/login/index.html
new file mode 100644
index 0000000..a42aef7
--- /dev/null
+++ b/internal/auth/login/index.html
@@ -0,0 +1,123 @@
+<!doctype html>
+<html lang="en">
+ <head>
+ <meta charset="UTF-8" />
+ <meta name="viewport" content="width=device-width, initial-scale=1.0" />
+ <title>Login</title>
+ <link
+ href="https://fonts.googleapis.com/css2?family=Inter:wght@400;600&display=swap"
+ rel="stylesheet"
+ />
+ <style>
+ body {
+ margin: 0;
+ font-family: "Inter", sans-serif;
+ background: linear-gradient(135deg, #667eea, #764ba2);
+ display: flex;
+ justify-content: center;
+ align-items: center;
+ height: 100vh;
+ }
+
+ .login-card {
+ background: #fff;
+ padding: 2rem;
+ border-radius: 16px;
+ box-shadow: 0 12px 30px rgba(0, 0, 0, 0.2);
+ max-width: 400px;
+ width: 100%;
+ box-sizing: border-box;
+ animation: fadeIn 0.5s ease-in-out;
+ }
+
+ .login-card h2 {
+ margin-top: 0;
+ margin-bottom: 1.5rem;
+ font-size: 1.75rem;
+ color: #333;
+ text-align: center;
+ }
+
+ .form-group {
+ margin-bottom: 1rem;
+ }
+
+ label {
+ display: block;
+ margin-bottom: 0.5rem;
+ font-weight: 600;
+ color: #555;
+ }
+
+ input[type="text"],
+ input[type="password"] {
+ width: 100%;
+ padding: 0.75rem;
+ border: 1px solid #ccc;
+ border-radius: 8px;
+ font-size: 1rem;
+ transition: border-color 0.2s;
+ }
+
+ input[type="text"]:focus,
+ input[type="password"]:focus {
+ border-color: #667eea;
+ outline: none;
+ }
+
+ button {
+ width: 100%;
+ padding: 0.75rem;
+ background-color: #667eea;
+ border: none;
+ border-radius: 8px;
+ color: white;
+ font-size: 1rem;
+ font-weight: 600;
+ cursor: pointer;
+ transition: background-color 0.3s ease;
+ }
+
+ button:hover {
+ background-color: #5a67d8;
+ }
+
+ .footer {
+ margin-top: 1rem;
+ text-align: center;
+ font-size: 0.9rem;
+ color: #888;
+ }
+
+ @keyframes fadeIn {
+ from {
+ opacity: 0;
+ transform: scale(0.95);
+ }
+ to {
+ opacity: 1;
+ transform: scale(1);
+ }
+ }
+ </style>
+ </head>
+ <body>
+ <div class="login-card">
+ <h2>Login</h2>
+ <form method="POST" action="/login">
+ <div class="form-group">
+ <label for="username">Username</label>
+ <input type="text" id="username" name="username" required />
+ </div>
+
+ <div class="form-group">
+ <label for="password">Password</label>
+ <input type="password" id="password" name="password" required />
+ </div>
+
+ <button type="submit">Sign In</button>
+ </form>
+ <div class="footer">&copy; Wacky404 Reverse Proxy Server</div>
+ </div>
+ </body>
+</html>
diff --git a/internal/cmd/root.go b/internal/cmd/root.go
index 6d9c271..1e24015 100644
--- a/internal/cmd/root.go
+++ b/internal/cmd/root.go
@@ -2,43 +2,103 @@ package cmd
import (
"context"
+ "fmt"
"log"
"net/http"
"net/http/httputil"
"net/url"
+ "sync"
"time"
"github.com/Wacky404/rpserver/internal/auth"
+ "github.com/Wacky404/rpserver/internal/middleware"
+ "github.com/golang-jwt/jwt/v5"
)
-func ExecuteServer(portNum string) error {
- // don't know if want to keep handlers in here
- http.HandleFunc("/proxy", handleProxy)
+var (
+ proxyCache = make(map[string]*httputil.ReverseProxy)
+ cacheMutex sync.RWMutex
+)
+
+func ExecuteServer(port string, cert string, key string) error {
+ mux := http.NewServeMux()
- log.Printf("Reverse Proxy running on :%v", portNum)
- err := http.ListenAndServe(":"+portNum, nil)
+ mux.Handle("/", middleware.Recover(http.HandlerFunc(serveLoginPage)))
+ mux.Handle("/auth/login", middleware.Recover(http.HandlerFunc(handleLogin)))
+ mux.Handle("/auth/signup", middleware.Recover(http.HandlerFunc(serveSignUpPage)))
+ mux.Handle("/proxy", middleware.Recover(middleware.JWT(http.HandlerFunc(handleProxy))))
+ mux.Handle("/status", middleware.Recover(http.HandlerFunc(handleStatus)))
+ err := http.ListenAndServeTLS(port, cert, key, mux)
return err
}
+func serveSignUpPage(w http.ResponseWriter, r *http.Request) {}
+
+func handleSignUp(w http.ResponseWriter, r *http.Request) {}
+
+func serveLoginPage(w http.ResponseWriter, r *http.Request) {
+ log.Println("Trying to server the login page!!!")
+ http.ServeFile(w, r, "internal/auth/login/index.html")
+}
+
+/* This function doesn't have proper auth for login creds */
+func handleLogin(w http.ResponseWriter, r *http.Request) {
+ if r.Method != http.MethodPost {
+ http.Redirect(w, r, "/", http.StatusSeeOther)
+ return
+ }
+
+ username := r.FormValue("username")
+ password := r.FormValue("password")
+
+ // pull this out into auth function
+ if username == "admin" && password == "password4321" {
+ token, err := auth.GenerateJWT(username, time.Hour)
+ if err != nil {
+ log.Printf("JWT generation error: %v", err)
+ http.Error(w, "Could not generate token:", http.StatusInternalServerError)
+ return
+ }
+
+ w.Header().Set("Content-Type", "application/json")
+ fmt.Fprintf(w, `{"token": "%s"}`, token)
+
+ return
+ }
+
+ http.Redirect(w, r, "/", http.StatusSeeOther)
+}
+
+func handleStatus(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(http.StatusOK)
+ fmt.Fprintln(w, "200 OK")
+}
+
func handleProxy(w http.ResponseWriter, r *http.Request) {
- if !auth.Verifyrequest(r) {
- http.Error(w, "Unauthorized", http.StatusUnauthorized)
+ claims, ok := r.Context().Value("claims").(jwt.MapClaims)
+ if !ok {
+ fmt.Println("Is this failing...")
+ http.Error(w, "Failed to get JWT claims", http.StatusInternalServerError)
return
}
+ userID := claims["sub"]
+ role := claims["role"]
+ fmt.Printf("%v, %v", userID, role)
backendURL, err := getBackendURL(r)
if err != nil {
+ fmt.Println("Is this failing...2")
http.Error(w, "Backend URL not provided", http.StatusBadRequest)
return
}
- proxy := httputil.NewSingleHostReverseProxy(backendURL)
+ proxy := getOrCreateProxy(backendURL)
ctx, cancel := context.WithTimeout(r.Context(), 5*time.Second)
defer cancel()
- r = r.WithContext(ctx) // attaching the new ctx to request
+ r = r.WithContext(ctx)
proxy.Director = func(req *http.Request) {
req.URL.Scheme = backendURL.Scheme
@@ -49,24 +109,43 @@ func handleProxy(w http.ResponseWriter, r *http.Request) {
done := make(chan struct{})
go func() {
- proxy.ServeHTTP(w, r) // Forward the request
+ proxy.ServeHTTP(w, r)
close(done)
}()
select {
- case <-ctx.Done(): // if context timout occurs
+ case <-ctx.Done():
http.Error(w, "Request timed out", http.StatusGatewayTimeout)
log.Println("Request to", r.URL.Path, "timed out...balls")
- case <-done: // if request completes successfully
+ case <-done:
}
}
func getBackendURL(r *http.Request) (*url.URL, error) {
- backend := r.Header.Get("X-Backend-URL") // extraction of backend from request
-
+ backend := r.Header.Get("X-Backend-URL")
if backend == "" {
- return nil, http.ErrNoLocation // error if no backend is in request
+ return nil, http.ErrNoLocation
}
+ return url.Parse(backend)
+}
+
+func getOrCreateProxy(target *url.URL) *httputil.ReverseProxy {
+ cacheMutex.RLock()
+ proxy, exists := proxyCache[target.String()]
+ cacheMutex.RUnlock()
+ if exists {
+ return proxy
+ }
+
+ cacheMutex.Lock()
+ defer cacheMutex.Unlock()
+
+ if proxy, exists = proxyCache[target.String()]; exists {
+ return proxy
+ }
+
+ proxy = httputil.NewSingleHostReverseProxy(target)
+ proxyCache[target.String()] = proxy
- return url.Parse(backend) // parse and return the backend url
+ return proxy
}
diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go
new file mode 100644
index 0000000..849e705
--- /dev/null
+++ b/internal/middleware/middleware.go
@@ -0,0 +1,38 @@
+package middleware
+
+import (
+ "context"
+ "log"
+ "net/http"
+
+ "github.com/Wacky404/rpserver/internal/auth"
+)
+
+type key string
+
+const claimsKey key = "claims"
+
+func JWT(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ claims, err := auth.VerifyRequest(r)
+ if err != nil {
+ http.Error(w, "Unauthorized", http.StatusUnauthorized)
+ return
+ }
+
+ ctx := context.WithValue(r.Context(), claimsKey, claims)
+ next.ServeHTTP(w, r.WithContext(ctx))
+ })
+}
+
+func Recover(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ defer func() {
+ if err := recover(); err != nil {
+ log.Printf("Recovered from panic: %v", err)
+ http.Error(w, "Internal Server Error", http.StatusInternalServerError)
+ }
+ }()
+ next.ServeHTTP(w, r)
+ })
+}
diff --git a/internal/proxies/proxy.go b/internal/proxies/proxy.go
new file mode 100644
index 0000000..769a888
--- /dev/null
+++ b/internal/proxies/proxy.go
@@ -0,0 +1 @@
+package proxies