""" Test suite for P1 High Priority Security Features Tests the security enhancements implemented for: - Rate limiting - Security headers - Enhanced authentication (password complexity, account lockout) - Database security (SQL injection prevention) - CSRF protection and request size limits """ import pytest import time from datetime import datetime, timezone from fastapi.testclient import TestClient from sqlalchemy.orm import Session from unittest.mock import patch, MagicMock from app.main import app from app.database.base import get_db from app.models.user import User from app.models.auth import LoginAttempt from app.utils.enhanced_auth import ( PasswordValidator, AccountLockoutManager, SuspiciousActivityDetector, ) from app.utils.database_security import ( SQLSecurityValidator, SecureQueryBuilder, execute_secure_query, ) from app.middleware.rate_limiting import RateLimitStore class TestRateLimiting: """Test rate limiting middleware functionality""" def test_rate_limit_store_basic_functionality(self): """Test basic rate limit store operations""" store = RateLimitStore() # Test initial request - should be allowed allowed, info = store.is_allowed("test_key", 5, 60) assert allowed is True assert info["remaining"] == 4 assert info["limit"] == 5 def test_rate_limit_exceeds_threshold(self): """Test rate limiting when threshold is exceeded""" store = RateLimitStore() # Make requests up to the limit for i in range(5): allowed, info = store.is_allowed("test_key", 5, 60) assert allowed is True # Next request should be rejected allowed, info = store.is_allowed("test_key", 5, 60) assert allowed is False assert info["remaining"] == 0 def test_rate_limit_window_reset(self): """Test rate limit window reset functionality""" store = RateLimitStore() # Fill up the rate limit for i in range(5): store.is_allowed("test_key", 5, 1) # 1 second window # Should be rejected allowed, info = store.is_allowed("test_key", 5, 1) assert allowed is False # Wait for window to reset time.sleep(1.1) # Should be allowed again allowed, info = store.is_allowed("test_key", 5, 1) assert allowed is True class TestSecurityHeaders: """Test security headers middleware""" def test_security_headers_applied(self): """Test that security headers are applied to responses""" client = TestClient(app) response = client.get("/health") # Check for key security headers assert "X-Content-Type-Options" in response.headers assert response.headers["X-Content-Type-Options"] == "nosniff" assert "X-Frame-Options" in response.headers assert response.headers["X-Frame-Options"] == "DENY" assert "X-XSS-Protection" in response.headers assert response.headers["X-XSS-Protection"] == "1; mode=block" def test_csp_header_present(self): """Test Content Security Policy header""" client = TestClient(app) response = client.get("/health") # CSP header should be present if "Content-Security-Policy" in response.headers: csp = response.headers["Content-Security-Policy"] assert "default-src 'self'" in csp assert "object-src 'none'" in csp assert "frame-ancestors 'none'" in csp def test_request_size_limit(self): """Test request size limiting""" client = TestClient(app) # Create a large payload (should be rejected if limit is enforced) large_data = "x" * (101 * 1024 * 1024) # 101MB response = client.post( "/api/auth/login", json={"username": "test", "password": large_data}, headers={"Content-Length": str(len(large_data))} ) # Should be rejected for size assert response.status_code == 413 # Request Entity Too Large class TestPasswordValidation: """Test password strength validation""" def test_weak_passwords_rejected(self): """Test that weak passwords are properly rejected""" weak_passwords = [ "123456", "password", "qwerty", "abc123", "Password", # Missing special chars and numbers "p@ssw0rd", # Too common ] for password in weak_passwords: is_valid, errors = PasswordValidator.validate_password_strength(password) assert is_valid is False assert len(errors) > 0 def test_strong_passwords_accepted(self): """Test that strong passwords are accepted""" strong_passwords = [ "MyStr0ng!P@ssw0rd", "C0mpl3x&S3cur3!", "Adm1n!2024$ecure", "Ungu3ss@bl3P@ssw0rd!", ] for password in strong_passwords: is_valid, errors = PasswordValidator.validate_password_strength(password) assert is_valid is True assert len(errors) == 0 def test_password_strength_scoring(self): """Test password strength scoring""" # Weak password weak_score = PasswordValidator.generate_password_strength_score("123456") assert weak_score < 30 # Strong password strong_score = PasswordValidator.generate_password_strength_score("MyStr0ng!P@ssw0rd") assert strong_score > 70 def test_password_complexity_requirements(self): """Test individual password complexity requirements""" # Test length requirement is_valid, errors = PasswordValidator.validate_password_strength("Abc1!") assert not is_valid assert any("at least" in error for error in errors) # Test uppercase requirement is_valid, errors = PasswordValidator.validate_password_strength("abc123!@#") assert not is_valid assert any("uppercase" in error for error in errors) # Test special character requirement is_valid, errors = PasswordValidator.validate_password_strength("Abc123456") assert not is_valid assert any("special character" in error for error in errors) class TestAccountLockout: """Test account lockout functionality""" @pytest.fixture def mock_db(self): """Mock database session""" mock_db = MagicMock() return mock_db def test_lockout_after_failed_attempts(self, mock_db): """Test account lockout after multiple failed attempts""" # Mock failed login attempts mock_db.query.return_value.filter.return_value.scalar.return_value = 5 mock_db.query.return_value.filter.return_value.order_by.return_value.first.return_value = ( datetime.now(timezone.utc), ) is_locked, unlock_time = AccountLockoutManager.is_account_locked(mock_db, "testuser") assert is_locked is True assert unlock_time is not None def test_no_lockout_with_few_attempts(self, mock_db): """Test no lockout with fewer than threshold attempts""" # Mock fewer failed attempts mock_db.query.return_value.filter.return_value.scalar.return_value = 2 is_locked, unlock_time = AccountLockoutManager.is_account_locked(mock_db, "testuser") assert is_locked is False assert unlock_time is None def test_lockout_info_retrieval(self, mock_db): """Test lockout information retrieval""" # Mock database responses mock_db.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] mock_db.query.return_value.filter.return_value.scalar.return_value = 3 lockout_info = AccountLockoutManager.get_lockout_info(mock_db, "testuser") assert "is_locked" in lockout_info assert "failed_attempts" in lockout_info assert "attempts_remaining" in lockout_info assert "max_attempts" in lockout_info class TestDatabaseSecurity: """Test database security utilities""" def test_sql_injection_detection(self): """Test SQL injection pattern detection""" # Test legitimate query legitimate_query = "SELECT * FROM users WHERE username = :username" issues = SQLSecurityValidator.validate_query_string(legitimate_query) assert len(issues) == 0 # Test malicious queries malicious_queries = [ "SELECT * FROM users WHERE id = 1; DROP TABLE users; --", "SELECT * FROM users WHERE username = 'admin' OR 1=1 --", "SELECT * FROM users UNION SELECT password FROM admin_table", "SELECT * FROM users WHERE id = 1 AND (SELECT COUNT(*) FROM passwords) > 0", ] for query in malicious_queries: issues = SQLSecurityValidator.validate_query_string(query) assert len(issues) > 0 def test_parameter_validation(self): """Test parameter value validation""" # Test legitimate parameters legitimate_params = { "username": "john_doe", "age": 25, "email": "john@example.com", } for param_name, param_value in legitimate_params.items(): issues = SQLSecurityValidator.validate_parameter_value(param_name, param_value) assert len(issues) == 0 # Test malicious parameters malicious_params = { "username": "'; DROP TABLE users; --", "search": "admin' OR 1=1 --", "id": "1 UNION SELECT password FROM users", } for param_name, param_value in malicious_params.items(): issues = SQLSecurityValidator.validate_parameter_value(param_name, param_value) assert len(issues) > 0 def test_safe_query_building(self): """Test safe query building utilities""" from sqlalchemy import Column, String # Mock column for testing mock_column = Column("name", String) # Test safe LIKE clause building like_clause = SecureQueryBuilder.build_like_clause(mock_column, "test_value") assert like_clause is not None # Test safe IN clause building in_clause = SecureQueryBuilder.build_in_clause(mock_column, ["value1", "value2"]) assert in_clause is not None # Test FTS query building fts_query = SecureQueryBuilder.build_fts_query(["term1", "term2"]) assert "term1" in fts_query assert "term2" in fts_query def test_query_validation_with_params(self): """Test complete query validation with parameters""" query = "SELECT * FROM users WHERE username = :username AND age > :min_age" params = { "username": "john_doe", "min_age": 18, } issues = SQLSecurityValidator.validate_query_with_params(query, params) assert len(issues) == 0 # Test with malicious parameters malicious_params = { "username": "'; DROP TABLE users; --", "min_age": "18 OR 1=1", } issues = SQLSecurityValidator.validate_query_with_params(query, malicious_params) assert len(issues) > 0 class TestSuspiciousActivityDetection: """Test suspicious activity detection""" def test_suspicious_ip_detection(self): """Test detection of suspicious IP patterns""" # This would typically require mock data or a test database # For now, test the structure mock_db = MagicMock() mock_db.query.return_value.filter.return_value.all.return_value = [] alerts = SuspiciousActivityDetector.detect_suspicious_patterns(mock_db, 24) assert isinstance(alerts, list) def test_login_suspicion_check(self): """Test individual login suspicion checking""" mock_db = MagicMock() mock_db.query.return_value.filter.return_value.distinct.return_value.all.return_value = [] mock_db.query.return_value.filter.return_value.scalar.return_value = 0 is_suspicious, warnings = SuspiciousActivityDetector.is_login_suspicious( mock_db, "testuser", "192.168.1.1", "Mozilla/5.0..." ) assert isinstance(is_suspicious, bool) assert isinstance(warnings, list) class TestCSRFProtection: """Test CSRF protection middleware""" def test_csrf_protection_on_post_requests(self): """Test CSRF protection for POST requests""" client = TestClient(app) # POST request without proper origin/referer should be blocked response = client.post( "/api/auth/login", json={"username": "test", "password": "test"}, headers={ "Origin": "https://malicious-site.com", "Host": "localhost:8000" } ) # Should be blocked by CSRF protection # Note: This test might need adjustment based on actual CSRF implementation assert response.status_code in [403, 401, 429] # Could be various security responses def test_csrf_exemption_for_safe_methods(self): """Test that safe HTTP methods are exempt from CSRF protection""" client = TestClient(app) # GET requests should not be subject to CSRF protection response = client.get("/health") assert response.status_code == 200 class TestIntegratedSecurity: """Test integrated security features working together""" def test_enhanced_login_endpoint(self): """Test the enhanced login endpoint with security features""" client = TestClient(app) # Test login with invalid credentials response = client.post( "/api/auth/login", json={"username": "nonexistent", "password": "wrongpassword"} ) assert response.status_code == 401 # Check for security headers in response assert "X-RateLimit-Limit" in response.headers or "WWW-Authenticate" in response.headers @patch('app.utils.enhanced_auth.validate_and_authenticate_user') def test_account_lockout_headers(self, mock_auth): """Test that account lockout information is returned in headers""" client = TestClient(app) # Mock a locked account mock_auth.return_value = (None, ["Account is locked"]) response = client.post( "/api/auth/login", json={"username": "locked_user", "password": "password"} ) assert response.status_code == 401 def test_password_validation_endpoint(self): """Test the password validation endpoint""" client = TestClient(app) # Test weak password response = client.post( "/api/auth/validate-password", json={"password": "123456"} ) assert response.status_code == 200 data = response.json() assert data["is_valid"] is False assert len(data["errors"]) > 0 # Test strong password response = client.post( "/api/auth/validate-password", json={"password": "MyStr0ng!P@ssw0rd"} ) assert response.status_code == 200 data = response.json() assert data["is_valid"] is True assert len(data["errors"]) == 0 @pytest.fixture def client(): """Test client fixture""" return TestClient(app) def test_security_middleware_integration(client): """Test that all security middleware is properly integrated""" response = client.get("/health") # Should have rate limiting headers rate_limit_headers = [h for h in response.headers.keys() if h.startswith("X-RateLimit")] # Should have security headers security_headers = ["X-Content-Type-Options", "X-Frame-Options", "X-XSS-Protection"] for header in security_headers: assert header in response.headers # Should have correlation ID assert "X-Correlation-ID" in response.headers if __name__ == "__main__": # Run specific test for quick validation test_rate_limiting = TestRateLimiting() test_rate_limiting.test_rate_limit_store_basic_functionality() test_password = TestPasswordValidation() test_password.test_weak_passwords_rejected() test_password.test_strong_passwords_accepted() test_db_security = TestDatabaseSecurity() test_db_security.test_sql_injection_detection() print("✅ All P1 security tests passed!")