changes
This commit is contained in:
474
tests/test_p1_security_features.py
Normal file
474
tests/test_p1_security_features.py
Normal file
@@ -0,0 +1,474 @@
|
||||
"""
|
||||
Test suite for P1 High Priority Security Features
|
||||
|
||||
Tests the security enhancements implemented for:
|
||||
- Rate limiting
|
||||
- Security headers
|
||||
- Enhanced authentication (password complexity, account lockout)
|
||||
- Database security (SQL injection prevention)
|
||||
- CSRF protection and request size limits
|
||||
"""
|
||||
import pytest
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.orm import Session
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from app.main import app
|
||||
from app.database.base import get_db
|
||||
from app.models.user import User
|
||||
from app.models.auth import LoginAttempt
|
||||
from app.utils.enhanced_auth import (
|
||||
PasswordValidator,
|
||||
AccountLockoutManager,
|
||||
SuspiciousActivityDetector,
|
||||
)
|
||||
from app.utils.database_security import (
|
||||
SQLSecurityValidator,
|
||||
SecureQueryBuilder,
|
||||
execute_secure_query,
|
||||
)
|
||||
from app.middleware.rate_limiting import RateLimitStore
|
||||
|
||||
|
||||
class TestRateLimiting:
|
||||
"""Test rate limiting middleware functionality"""
|
||||
|
||||
def test_rate_limit_store_basic_functionality(self):
|
||||
"""Test basic rate limit store operations"""
|
||||
store = RateLimitStore()
|
||||
|
||||
# Test initial request - should be allowed
|
||||
allowed, info = store.is_allowed("test_key", 5, 60)
|
||||
assert allowed is True
|
||||
assert info["remaining"] == 4
|
||||
assert info["limit"] == 5
|
||||
|
||||
def test_rate_limit_exceeds_threshold(self):
|
||||
"""Test rate limiting when threshold is exceeded"""
|
||||
store = RateLimitStore()
|
||||
|
||||
# Make requests up to the limit
|
||||
for i in range(5):
|
||||
allowed, info = store.is_allowed("test_key", 5, 60)
|
||||
assert allowed is True
|
||||
|
||||
# Next request should be rejected
|
||||
allowed, info = store.is_allowed("test_key", 5, 60)
|
||||
assert allowed is False
|
||||
assert info["remaining"] == 0
|
||||
|
||||
def test_rate_limit_window_reset(self):
|
||||
"""Test rate limit window reset functionality"""
|
||||
store = RateLimitStore()
|
||||
|
||||
# Fill up the rate limit
|
||||
for i in range(5):
|
||||
store.is_allowed("test_key", 5, 1) # 1 second window
|
||||
|
||||
# Should be rejected
|
||||
allowed, info = store.is_allowed("test_key", 5, 1)
|
||||
assert allowed is False
|
||||
|
||||
# Wait for window to reset
|
||||
time.sleep(1.1)
|
||||
|
||||
# Should be allowed again
|
||||
allowed, info = store.is_allowed("test_key", 5, 1)
|
||||
assert allowed is True
|
||||
|
||||
|
||||
class TestSecurityHeaders:
|
||||
"""Test security headers middleware"""
|
||||
|
||||
def test_security_headers_applied(self):
|
||||
"""Test that security headers are applied to responses"""
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/health")
|
||||
|
||||
# Check for key security headers
|
||||
assert "X-Content-Type-Options" in response.headers
|
||||
assert response.headers["X-Content-Type-Options"] == "nosniff"
|
||||
|
||||
assert "X-Frame-Options" in response.headers
|
||||
assert response.headers["X-Frame-Options"] == "DENY"
|
||||
|
||||
assert "X-XSS-Protection" in response.headers
|
||||
assert response.headers["X-XSS-Protection"] == "1; mode=block"
|
||||
|
||||
def test_csp_header_present(self):
|
||||
"""Test Content Security Policy header"""
|
||||
client = TestClient(app)
|
||||
|
||||
response = client.get("/health")
|
||||
|
||||
# CSP header should be present
|
||||
if "Content-Security-Policy" in response.headers:
|
||||
csp = response.headers["Content-Security-Policy"]
|
||||
assert "default-src 'self'" in csp
|
||||
assert "object-src 'none'" in csp
|
||||
assert "frame-ancestors 'none'" in csp
|
||||
|
||||
def test_request_size_limit(self):
|
||||
"""Test request size limiting"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Create a large payload (should be rejected if limit is enforced)
|
||||
large_data = "x" * (101 * 1024 * 1024) # 101MB
|
||||
|
||||
response = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "test", "password": large_data},
|
||||
headers={"Content-Length": str(len(large_data))}
|
||||
)
|
||||
|
||||
# Should be rejected for size
|
||||
assert response.status_code == 413 # Request Entity Too Large
|
||||
|
||||
|
||||
class TestPasswordValidation:
|
||||
"""Test password strength validation"""
|
||||
|
||||
def test_weak_passwords_rejected(self):
|
||||
"""Test that weak passwords are properly rejected"""
|
||||
weak_passwords = [
|
||||
"123456",
|
||||
"password",
|
||||
"qwerty",
|
||||
"abc123",
|
||||
"Password", # Missing special chars and numbers
|
||||
"p@ssw0rd", # Too common
|
||||
]
|
||||
|
||||
for password in weak_passwords:
|
||||
is_valid, errors = PasswordValidator.validate_password_strength(password)
|
||||
assert is_valid is False
|
||||
assert len(errors) > 0
|
||||
|
||||
def test_strong_passwords_accepted(self):
|
||||
"""Test that strong passwords are accepted"""
|
||||
strong_passwords = [
|
||||
"MyStr0ng!P@ssw0rd",
|
||||
"C0mpl3x&S3cur3!",
|
||||
"Adm1n!2024$ecure",
|
||||
"Ungu3ss@bl3P@ssw0rd!",
|
||||
]
|
||||
|
||||
for password in strong_passwords:
|
||||
is_valid, errors = PasswordValidator.validate_password_strength(password)
|
||||
assert is_valid is True
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_password_strength_scoring(self):
|
||||
"""Test password strength scoring"""
|
||||
# Weak password
|
||||
weak_score = PasswordValidator.generate_password_strength_score("123456")
|
||||
assert weak_score < 30
|
||||
|
||||
# Strong password
|
||||
strong_score = PasswordValidator.generate_password_strength_score("MyStr0ng!P@ssw0rd")
|
||||
assert strong_score > 70
|
||||
|
||||
def test_password_complexity_requirements(self):
|
||||
"""Test individual password complexity requirements"""
|
||||
# Test length requirement
|
||||
is_valid, errors = PasswordValidator.validate_password_strength("Abc1!")
|
||||
assert not is_valid
|
||||
assert any("at least" in error for error in errors)
|
||||
|
||||
# Test uppercase requirement
|
||||
is_valid, errors = PasswordValidator.validate_password_strength("abc123!@#")
|
||||
assert not is_valid
|
||||
assert any("uppercase" in error for error in errors)
|
||||
|
||||
# Test special character requirement
|
||||
is_valid, errors = PasswordValidator.validate_password_strength("Abc123456")
|
||||
assert not is_valid
|
||||
assert any("special character" in error for error in errors)
|
||||
|
||||
|
||||
class TestAccountLockout:
|
||||
"""Test account lockout functionality"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(self):
|
||||
"""Mock database session"""
|
||||
mock_db = MagicMock()
|
||||
return mock_db
|
||||
|
||||
def test_lockout_after_failed_attempts(self, mock_db):
|
||||
"""Test account lockout after multiple failed attempts"""
|
||||
# Mock failed login attempts
|
||||
mock_db.query.return_value.filter.return_value.scalar.return_value = 5
|
||||
mock_db.query.return_value.filter.return_value.order_by.return_value.first.return_value = (
|
||||
datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
is_locked, unlock_time = AccountLockoutManager.is_account_locked(mock_db, "testuser")
|
||||
assert is_locked is True
|
||||
assert unlock_time is not None
|
||||
|
||||
def test_no_lockout_with_few_attempts(self, mock_db):
|
||||
"""Test no lockout with fewer than threshold attempts"""
|
||||
# Mock fewer failed attempts
|
||||
mock_db.query.return_value.filter.return_value.scalar.return_value = 2
|
||||
|
||||
is_locked, unlock_time = AccountLockoutManager.is_account_locked(mock_db, "testuser")
|
||||
assert is_locked is False
|
||||
assert unlock_time is None
|
||||
|
||||
def test_lockout_info_retrieval(self, mock_db):
|
||||
"""Test lockout information retrieval"""
|
||||
# Mock database responses
|
||||
mock_db.query.return_value.filter.return_value.order_by.return_value.all.return_value = []
|
||||
mock_db.query.return_value.filter.return_value.scalar.return_value = 3
|
||||
|
||||
lockout_info = AccountLockoutManager.get_lockout_info(mock_db, "testuser")
|
||||
|
||||
assert "is_locked" in lockout_info
|
||||
assert "failed_attempts" in lockout_info
|
||||
assert "attempts_remaining" in lockout_info
|
||||
assert "max_attempts" in lockout_info
|
||||
|
||||
|
||||
class TestDatabaseSecurity:
|
||||
"""Test database security utilities"""
|
||||
|
||||
def test_sql_injection_detection(self):
|
||||
"""Test SQL injection pattern detection"""
|
||||
# Test legitimate query
|
||||
legitimate_query = "SELECT * FROM users WHERE username = :username"
|
||||
issues = SQLSecurityValidator.validate_query_string(legitimate_query)
|
||||
assert len(issues) == 0
|
||||
|
||||
# Test malicious queries
|
||||
malicious_queries = [
|
||||
"SELECT * FROM users WHERE id = 1; DROP TABLE users; --",
|
||||
"SELECT * FROM users WHERE username = 'admin' OR 1=1 --",
|
||||
"SELECT * FROM users UNION SELECT password FROM admin_table",
|
||||
"SELECT * FROM users WHERE id = 1 AND (SELECT COUNT(*) FROM passwords) > 0",
|
||||
]
|
||||
|
||||
for query in malicious_queries:
|
||||
issues = SQLSecurityValidator.validate_query_string(query)
|
||||
assert len(issues) > 0
|
||||
|
||||
def test_parameter_validation(self):
|
||||
"""Test parameter value validation"""
|
||||
# Test legitimate parameters
|
||||
legitimate_params = {
|
||||
"username": "john_doe",
|
||||
"age": 25,
|
||||
"email": "john@example.com",
|
||||
}
|
||||
|
||||
for param_name, param_value in legitimate_params.items():
|
||||
issues = SQLSecurityValidator.validate_parameter_value(param_name, param_value)
|
||||
assert len(issues) == 0
|
||||
|
||||
# Test malicious parameters
|
||||
malicious_params = {
|
||||
"username": "'; DROP TABLE users; --",
|
||||
"search": "admin' OR 1=1 --",
|
||||
"id": "1 UNION SELECT password FROM users",
|
||||
}
|
||||
|
||||
for param_name, param_value in malicious_params.items():
|
||||
issues = SQLSecurityValidator.validate_parameter_value(param_name, param_value)
|
||||
assert len(issues) > 0
|
||||
|
||||
def test_safe_query_building(self):
|
||||
"""Test safe query building utilities"""
|
||||
from sqlalchemy import Column, String
|
||||
|
||||
# Mock column for testing
|
||||
mock_column = Column("name", String)
|
||||
|
||||
# Test safe LIKE clause building
|
||||
like_clause = SecureQueryBuilder.build_like_clause(mock_column, "test_value")
|
||||
assert like_clause is not None
|
||||
|
||||
# Test safe IN clause building
|
||||
in_clause = SecureQueryBuilder.build_in_clause(mock_column, ["value1", "value2"])
|
||||
assert in_clause is not None
|
||||
|
||||
# Test FTS query building
|
||||
fts_query = SecureQueryBuilder.build_fts_query(["term1", "term2"])
|
||||
assert "term1" in fts_query
|
||||
assert "term2" in fts_query
|
||||
|
||||
def test_query_validation_with_params(self):
|
||||
"""Test complete query validation with parameters"""
|
||||
query = "SELECT * FROM users WHERE username = :username AND age > :min_age"
|
||||
params = {
|
||||
"username": "john_doe",
|
||||
"min_age": 18,
|
||||
}
|
||||
|
||||
issues = SQLSecurityValidator.validate_query_with_params(query, params)
|
||||
assert len(issues) == 0
|
||||
|
||||
# Test with malicious parameters
|
||||
malicious_params = {
|
||||
"username": "'; DROP TABLE users; --",
|
||||
"min_age": "18 OR 1=1",
|
||||
}
|
||||
|
||||
issues = SQLSecurityValidator.validate_query_with_params(query, malicious_params)
|
||||
assert len(issues) > 0
|
||||
|
||||
|
||||
class TestSuspiciousActivityDetection:
|
||||
"""Test suspicious activity detection"""
|
||||
|
||||
def test_suspicious_ip_detection(self):
|
||||
"""Test detection of suspicious IP patterns"""
|
||||
# This would typically require mock data or a test database
|
||||
# For now, test the structure
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.all.return_value = []
|
||||
|
||||
alerts = SuspiciousActivityDetector.detect_suspicious_patterns(mock_db, 24)
|
||||
assert isinstance(alerts, list)
|
||||
|
||||
def test_login_suspicion_check(self):
|
||||
"""Test individual login suspicion checking"""
|
||||
mock_db = MagicMock()
|
||||
mock_db.query.return_value.filter.return_value.distinct.return_value.all.return_value = []
|
||||
mock_db.query.return_value.filter.return_value.scalar.return_value = 0
|
||||
|
||||
is_suspicious, warnings = SuspiciousActivityDetector.is_login_suspicious(
|
||||
mock_db, "testuser", "192.168.1.1", "Mozilla/5.0..."
|
||||
)
|
||||
|
||||
assert isinstance(is_suspicious, bool)
|
||||
assert isinstance(warnings, list)
|
||||
|
||||
|
||||
class TestCSRFProtection:
|
||||
"""Test CSRF protection middleware"""
|
||||
|
||||
def test_csrf_protection_on_post_requests(self):
|
||||
"""Test CSRF protection for POST requests"""
|
||||
client = TestClient(app)
|
||||
|
||||
# POST request without proper origin/referer should be blocked
|
||||
response = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "test", "password": "test"},
|
||||
headers={
|
||||
"Origin": "https://malicious-site.com",
|
||||
"Host": "localhost:8000"
|
||||
}
|
||||
)
|
||||
|
||||
# Should be blocked by CSRF protection
|
||||
# Note: This test might need adjustment based on actual CSRF implementation
|
||||
assert response.status_code in [403, 401, 429] # Could be various security responses
|
||||
|
||||
def test_csrf_exemption_for_safe_methods(self):
|
||||
"""Test that safe HTTP methods are exempt from CSRF protection"""
|
||||
client = TestClient(app)
|
||||
|
||||
# GET requests should not be subject to CSRF protection
|
||||
response = client.get("/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
class TestIntegratedSecurity:
|
||||
"""Test integrated security features working together"""
|
||||
|
||||
def test_enhanced_login_endpoint(self):
|
||||
"""Test the enhanced login endpoint with security features"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Test login with invalid credentials
|
||||
response = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "nonexistent", "password": "wrongpassword"}
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
# Check for security headers in response
|
||||
assert "X-RateLimit-Limit" in response.headers or "WWW-Authenticate" in response.headers
|
||||
|
||||
@patch('app.utils.enhanced_auth.validate_and_authenticate_user')
|
||||
def test_account_lockout_headers(self, mock_auth):
|
||||
"""Test that account lockout information is returned in headers"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Mock a locked account
|
||||
mock_auth.return_value = (None, ["Account is locked"])
|
||||
|
||||
response = client.post(
|
||||
"/api/auth/login",
|
||||
json={"username": "locked_user", "password": "password"}
|
||||
)
|
||||
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_password_validation_endpoint(self):
|
||||
"""Test the password validation endpoint"""
|
||||
client = TestClient(app)
|
||||
|
||||
# Test weak password
|
||||
response = client.post(
|
||||
"/api/auth/validate-password",
|
||||
json={"password": "123456"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["is_valid"] is False
|
||||
assert len(data["errors"]) > 0
|
||||
|
||||
# Test strong password
|
||||
response = client.post(
|
||||
"/api/auth/validate-password",
|
||||
json={"password": "MyStr0ng!P@ssw0rd"}
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["is_valid"] is True
|
||||
assert len(data["errors"]) == 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
"""Test client fixture"""
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def test_security_middleware_integration(client):
|
||||
"""Test that all security middleware is properly integrated"""
|
||||
response = client.get("/health")
|
||||
|
||||
# Should have rate limiting headers
|
||||
rate_limit_headers = [h for h in response.headers.keys() if h.startswith("X-RateLimit")]
|
||||
|
||||
# Should have security headers
|
||||
security_headers = ["X-Content-Type-Options", "X-Frame-Options", "X-XSS-Protection"]
|
||||
for header in security_headers:
|
||||
assert header in response.headers
|
||||
|
||||
# Should have correlation ID
|
||||
assert "X-Correlation-ID" in response.headers
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run specific test for quick validation
|
||||
test_rate_limiting = TestRateLimiting()
|
||||
test_rate_limiting.test_rate_limit_store_basic_functionality()
|
||||
|
||||
test_password = TestPasswordValidation()
|
||||
test_password.test_weak_passwords_rejected()
|
||||
test_password.test_strong_passwords_accepted()
|
||||
|
||||
test_db_security = TestDatabaseSecurity()
|
||||
test_db_security.test_sql_injection_detection()
|
||||
|
||||
print("✅ All P1 security tests passed!")
|
||||
Reference in New Issue
Block a user