422 lines
15 KiB
Python
422 lines
15 KiB
Python
"""
|
|
Tests for better_auth_init.py
|
|
|
|
Covers main functionality with mocked I/O and file operations.
|
|
Target: >80% coverage
|
|
"""
|
|
|
|
import sys
|
|
import pytest
|
|
from pathlib import Path
|
|
from unittest.mock import Mock, patch, mock_open, MagicMock
|
|
from io import StringIO
|
|
|
|
# Add parent directory to path
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
from better_auth_init import BetterAuthInit, EnvConfig, main
|
|
|
|
|
|
@pytest.fixture
|
|
def mock_project_root(tmp_path):
|
|
"""Create mock project root with package.json."""
|
|
(tmp_path / "package.json").write_text("{}")
|
|
return tmp_path
|
|
|
|
|
|
@pytest.fixture
|
|
def auth_init(mock_project_root):
|
|
"""Create BetterAuthInit instance with mock project root."""
|
|
return BetterAuthInit(project_root=mock_project_root)
|
|
|
|
|
|
class TestBetterAuthInit:
|
|
"""Test BetterAuthInit class."""
|
|
|
|
def test_init_with_project_root(self, mock_project_root):
|
|
"""Test initialization with explicit project root."""
|
|
init = BetterAuthInit(project_root=mock_project_root)
|
|
assert init.project_root == mock_project_root
|
|
assert init.env_config is None
|
|
|
|
def test_find_project_root_success(self, mock_project_root, monkeypatch):
|
|
"""Test finding project root successfully."""
|
|
monkeypatch.chdir(mock_project_root)
|
|
init = BetterAuthInit()
|
|
assert init.project_root == mock_project_root
|
|
|
|
def test_find_project_root_failure(self, tmp_path, monkeypatch):
|
|
"""Test failure to find project root."""
|
|
# Create path without package.json
|
|
no_package_dir = tmp_path / "no-package"
|
|
no_package_dir.mkdir()
|
|
monkeypatch.chdir(no_package_dir)
|
|
|
|
# Mock parent to stop infinite loop
|
|
with patch.object(Path, "parent", new_callable=lambda: property(lambda self: self)):
|
|
with pytest.raises(RuntimeError, match="Could not find project root"):
|
|
BetterAuthInit()
|
|
|
|
def test_generate_secret(self):
|
|
"""Test secret generation."""
|
|
secret = BetterAuthInit.generate_secret()
|
|
assert len(secret) == 64 # 32 bytes = 64 hex chars
|
|
assert all(c in "0123456789abcdef" for c in secret)
|
|
|
|
# Test custom length
|
|
secret = BetterAuthInit.generate_secret(length=16)
|
|
assert len(secret) == 32 # 16 bytes = 32 hex chars
|
|
|
|
def test_parse_env_file(self, tmp_path):
|
|
"""Test parsing .env file."""
|
|
env_content = """
|
|
# Comment
|
|
KEY1=value1
|
|
KEY2="value2"
|
|
KEY3='value3'
|
|
INVALID LINE
|
|
KEY4=value=with=equals
|
|
"""
|
|
env_file = tmp_path / ".env"
|
|
env_file.write_text(env_content)
|
|
|
|
result = BetterAuthInit._parse_env_file(env_file)
|
|
|
|
assert result["KEY1"] == "value1"
|
|
assert result["KEY2"] == "value2"
|
|
assert result["KEY3"] == "value3"
|
|
assert result["KEY4"] == "value=with=equals"
|
|
assert "INVALID" not in result
|
|
|
|
def test_parse_env_file_missing(self, tmp_path):
|
|
"""Test parsing missing .env file."""
|
|
result = BetterAuthInit._parse_env_file(tmp_path / "nonexistent.env")
|
|
assert result == {}
|
|
|
|
def test_load_env_files(self, auth_init, mock_project_root):
|
|
"""Test loading environment variables from multiple files."""
|
|
# Create .env files
|
|
claude_env = mock_project_root / ".claude" / ".env"
|
|
claude_env.parent.mkdir(parents=True, exist_ok=True)
|
|
claude_env.write_text("BASE_VAR=base\nOVERRIDE=claude")
|
|
|
|
skills_env = mock_project_root / ".claude" / "skills" / ".env"
|
|
skills_env.parent.mkdir(parents=True, exist_ok=True)
|
|
skills_env.write_text("OVERRIDE=skills\nSKILLS_VAR=skills")
|
|
|
|
# Mock process env (highest priority)
|
|
with patch.dict("os.environ", {"OVERRIDE": "process", "PROCESS_VAR": "process"}):
|
|
result = auth_init._load_env_files()
|
|
|
|
assert result["BASE_VAR"] == "base"
|
|
assert result["SKILLS_VAR"] == "skills"
|
|
assert result["OVERRIDE"] == "process" # Process env wins
|
|
assert result["PROCESS_VAR"] == "process"
|
|
|
|
def test_prompt_direct_db_sqlite(self, auth_init):
|
|
"""Test prompting for SQLite database."""
|
|
with patch("builtins.input", side_effect=["3", "./test.db"]):
|
|
config = auth_init._prompt_direct_db()
|
|
|
|
assert config["type"] == "sqlite"
|
|
assert "better-sqlite3" in config["import"]
|
|
assert "./test.db" in config["config"]
|
|
|
|
def test_prompt_direct_db_postgresql(self, auth_init):
|
|
"""Test prompting for PostgreSQL database."""
|
|
with patch("builtins.input", side_effect=["1", "postgresql://localhost/test"]):
|
|
config = auth_init._prompt_direct_db()
|
|
|
|
assert config["type"] == "postgresql"
|
|
assert "pg" in config["import"]
|
|
assert config["env_var"] == ("DATABASE_URL", "postgresql://localhost/test")
|
|
|
|
def test_prompt_direct_db_mysql(self, auth_init):
|
|
"""Test prompting for MySQL database."""
|
|
with patch("builtins.input", side_effect=["2", "mysql://localhost/test"]):
|
|
config = auth_init._prompt_direct_db()
|
|
|
|
assert config["type"] == "mysql"
|
|
assert "mysql2" in config["import"]
|
|
assert config["env_var"][0] == "DATABASE_URL"
|
|
|
|
def test_prompt_drizzle(self, auth_init):
|
|
"""Test prompting for Drizzle ORM."""
|
|
with patch("builtins.input", return_value="1"):
|
|
config = auth_init._prompt_drizzle()
|
|
|
|
assert config["type"] == "drizzle"
|
|
assert config["provider"] == "pg"
|
|
assert "drizzleAdapter" in config["import"]
|
|
assert "drizzleAdapter" in config["config"]
|
|
|
|
def test_prompt_prisma(self, auth_init):
|
|
"""Test prompting for Prisma."""
|
|
with patch("builtins.input", return_value="2"):
|
|
config = auth_init._prompt_prisma()
|
|
|
|
assert config["type"] == "prisma"
|
|
assert config["provider"] == "mysql"
|
|
assert "prismaAdapter" in config["import"]
|
|
assert "PrismaClient" in config["import"]
|
|
|
|
def test_prompt_kysely(self, auth_init):
|
|
"""Test prompting for Kysely."""
|
|
config = auth_init._prompt_kysely()
|
|
|
|
assert config["type"] == "kysely"
|
|
assert "kyselyAdapter" in config["import"]
|
|
|
|
def test_prompt_mongodb(self, auth_init):
|
|
"""Test prompting for MongoDB."""
|
|
with patch("builtins.input", side_effect=["mongodb://localhost/test", "mydb"]):
|
|
config = auth_init._prompt_mongodb()
|
|
|
|
assert config["type"] == "mongodb"
|
|
assert "mongodbAdapter" in config["import"]
|
|
assert "mydb" in config["config"]
|
|
assert config["env_var"] == ("MONGODB_URI", "mongodb://localhost/test")
|
|
|
|
def test_prompt_database(self, auth_init):
|
|
"""Test database prompting with different choices."""
|
|
# Test valid choice
|
|
with patch("builtins.input", side_effect=["3", "1"]):
|
|
config = auth_init.prompt_database()
|
|
assert config["type"] == "prisma"
|
|
|
|
# Test invalid choice (defaults to direct DB)
|
|
with patch("builtins.input", side_effect=["99", "1", "postgresql://localhost/test"]):
|
|
with patch("builtins.print"):
|
|
config = auth_init.prompt_database()
|
|
assert config["type"] == "postgresql"
|
|
|
|
def test_prompt_auth_methods(self, auth_init):
|
|
"""Test prompting for authentication methods."""
|
|
with patch("builtins.input", return_value="1 2 3 5 8"):
|
|
with patch("builtins.print"):
|
|
methods = auth_init.prompt_auth_methods()
|
|
|
|
assert methods == ["1", "2", "3", "5", "8"]
|
|
|
|
def test_prompt_auth_methods_invalid(self, auth_init):
|
|
"""Test filtering invalid auth method choices."""
|
|
with patch("builtins.input", return_value="1 99 abc 3"):
|
|
with patch("builtins.print"):
|
|
methods = auth_init.prompt_auth_methods()
|
|
|
|
assert methods == ["1", "3"]
|
|
|
|
def test_generate_auth_config_basic(self, auth_init):
|
|
"""Test generating basic auth config."""
|
|
db_config = {
|
|
"import": "import Database from 'better-sqlite3';",
|
|
"config": "database: new Database('./dev.db')"
|
|
}
|
|
auth_methods = ["1"] # Email/password only
|
|
|
|
config = auth_init.generate_auth_config(db_config, auth_methods)
|
|
|
|
assert "import { betterAuth }" in config
|
|
assert "emailAndPassword" in config
|
|
assert "enabled: true" in config
|
|
assert "better-sqlite3" in config
|
|
|
|
def test_generate_auth_config_with_oauth(self, auth_init):
|
|
"""Test generating config with OAuth providers."""
|
|
db_config = {
|
|
"import": "import { Pool } from 'pg';",
|
|
"config": "database: new Pool()"
|
|
}
|
|
auth_methods = ["1", "2", "3", "4"] # Email + GitHub + Google + Discord
|
|
|
|
config = auth_init.generate_auth_config(db_config, auth_methods)
|
|
|
|
assert "socialProviders" in config
|
|
assert "github:" in config
|
|
assert "google:" in config
|
|
assert "discord:" in config
|
|
assert "GITHUB_CLIENT_ID" in config
|
|
assert "GOOGLE_CLIENT_ID" in config
|
|
assert "DISCORD_CLIENT_ID" in config
|
|
|
|
def test_generate_auth_config_with_plugins(self, auth_init):
|
|
"""Test generating config with plugins."""
|
|
db_config = {"import": "", "config": "database: db"}
|
|
auth_methods = ["5", "6", "7", "8"] # 2FA, Passkey, Magic Link, Username
|
|
|
|
config = auth_init.generate_auth_config(db_config, auth_methods)
|
|
|
|
assert "plugins:" in config
|
|
assert "twoFactor" in config
|
|
assert "passkey" in config
|
|
assert "magicLink" in config
|
|
assert "username" in config
|
|
assert "from 'better-auth/plugins'" in config
|
|
|
|
def test_generate_env_file_basic(self, auth_init):
|
|
"""Test generating basic .env file."""
|
|
db_config = {"type": "sqlite"}
|
|
auth_methods = ["1"]
|
|
|
|
env_content = auth_init.generate_env_file(db_config, auth_methods)
|
|
|
|
assert "BETTER_AUTH_SECRET=" in env_content
|
|
assert "BETTER_AUTH_URL=http://localhost:3000" in env_content
|
|
assert len(env_content.split("\n")) >= 2
|
|
|
|
def test_generate_env_file_with_database_url(self, auth_init):
|
|
"""Test generating .env with database URL."""
|
|
db_config = {
|
|
"env_var": ("DATABASE_URL", "postgresql://localhost/test")
|
|
}
|
|
auth_methods = []
|
|
|
|
env_content = auth_init.generate_env_file(db_config, auth_methods)
|
|
|
|
assert "DATABASE_URL=postgresql://localhost/test" in env_content
|
|
|
|
def test_generate_env_file_with_oauth(self, auth_init):
|
|
"""Test generating .env with OAuth credentials."""
|
|
db_config = {}
|
|
auth_methods = ["2", "3", "4"] # GitHub, Google, Discord
|
|
|
|
env_content = auth_init.generate_env_file(db_config, auth_methods)
|
|
|
|
assert "GITHUB_CLIENT_ID=" in env_content
|
|
assert "GITHUB_CLIENT_SECRET=" in env_content
|
|
assert "GOOGLE_CLIENT_ID=" in env_content
|
|
assert "GOOGLE_CLIENT_SECRET=" in env_content
|
|
assert "DISCORD_CLIENT_ID=" in env_content
|
|
assert "DISCORD_CLIENT_SECRET=" in env_content
|
|
|
|
def test_save_files(self, auth_init, mock_project_root):
|
|
"""Test saving configuration files."""
|
|
auth_config = "// auth config"
|
|
env_content = "SECRET=test"
|
|
|
|
with patch("builtins.input", side_effect=["1"]):
|
|
auth_init._save_files(auth_config, env_content)
|
|
|
|
# Check auth.ts was saved
|
|
auth_path = mock_project_root / "lib" / "auth.ts"
|
|
assert auth_path.exists()
|
|
assert auth_path.read_text() == auth_config
|
|
|
|
# Check .env was saved
|
|
env_path = mock_project_root / ".env"
|
|
assert env_path.exists()
|
|
assert env_path.read_text() == env_content
|
|
|
|
def test_save_files_custom_path(self, auth_init, mock_project_root):
|
|
"""Test saving with custom path."""
|
|
auth_config = "// config"
|
|
env_content = "SECRET=test"
|
|
|
|
custom_path = str(mock_project_root / "custom" / "auth.ts")
|
|
with patch("builtins.input", side_effect=["5", custom_path]):
|
|
auth_init._save_files(auth_config, env_content)
|
|
|
|
assert Path(custom_path).exists()
|
|
|
|
def test_save_files_backup_existing_env(self, auth_init, mock_project_root):
|
|
"""Test backing up existing .env file."""
|
|
# Create existing .env
|
|
env_path = mock_project_root / ".env"
|
|
env_path.write_text("OLD_SECRET=old")
|
|
|
|
auth_config = "// config"
|
|
env_content = "NEW_SECRET=new"
|
|
|
|
with patch("builtins.input", return_value="1"):
|
|
auth_init._save_files(auth_config, env_content)
|
|
|
|
# Check backup was created
|
|
backup_path = mock_project_root / ".env.backup"
|
|
assert backup_path.exists()
|
|
assert backup_path.read_text() == "OLD_SECRET=old"
|
|
|
|
# Check new .env
|
|
assert env_path.read_text() == "NEW_SECRET=new"
|
|
|
|
def test_run_full_flow(self, auth_init, mock_project_root):
|
|
"""Test complete run flow."""
|
|
inputs = [
|
|
"1", # Direct DB
|
|
"1", # PostgreSQL
|
|
"postgresql://localhost/test",
|
|
"1 2", # Email + GitHub
|
|
"n" # Don't save
|
|
]
|
|
|
|
with patch("builtins.input", side_effect=inputs):
|
|
with patch("builtins.print"):
|
|
auth_init.run()
|
|
|
|
# Should complete without errors
|
|
# Files not saved because user chose 'n'
|
|
assert not (mock_project_root / "auth.ts").exists()
|
|
|
|
def test_run_save_files(self, auth_init, mock_project_root):
|
|
"""Test run flow with file saving."""
|
|
inputs = [
|
|
"1", # Direct DB
|
|
"3", # SQLite
|
|
"", # Default path
|
|
"1", # Email only
|
|
"y", # Save
|
|
"1" # Save location
|
|
]
|
|
|
|
with patch("builtins.input", side_effect=inputs):
|
|
with patch("builtins.print"):
|
|
auth_init.run()
|
|
|
|
# Check files were created
|
|
assert (mock_project_root / "lib" / "auth.ts").exists()
|
|
assert (mock_project_root / ".env").exists()
|
|
|
|
|
|
class TestMainFunction:
|
|
"""Test main entry point."""
|
|
|
|
def test_main_success(self, tmp_path, monkeypatch):
|
|
"""Test successful main execution."""
|
|
(tmp_path / "package.json").write_text("{}")
|
|
monkeypatch.chdir(tmp_path)
|
|
|
|
inputs = ["1", "3", "", "1", "n"]
|
|
|
|
with patch("builtins.input", side_effect=inputs):
|
|
with patch("builtins.print"):
|
|
exit_code = main()
|
|
|
|
assert exit_code == 0
|
|
|
|
def test_main_keyboard_interrupt(self, tmp_path, monkeypatch):
|
|
"""Test main with keyboard interrupt."""
|
|
(tmp_path / "package.json").write_text("{}")
|
|
monkeypatch.chdir(tmp_path)
|
|
|
|
with patch("builtins.input", side_effect=KeyboardInterrupt()):
|
|
with patch("builtins.print"):
|
|
exit_code = main()
|
|
|
|
assert exit_code == 1
|
|
|
|
def test_main_error(self, tmp_path, monkeypatch):
|
|
"""Test main with error."""
|
|
# No package.json - should fail
|
|
no_package = tmp_path / "no-package"
|
|
no_package.mkdir()
|
|
monkeypatch.chdir(no_package)
|
|
|
|
with patch.object(Path, "parent", new_callable=lambda: property(lambda self: self)):
|
|
with patch("sys.stderr", new_callable=StringIO):
|
|
exit_code = main()
|
|
|
|
assert exit_code == 1
|
|
|
|
|
|
if __name__ == "__main__":
|
|
pytest.main([__file__, "-v", "--cov=better_auth_init", "--cov-report=term-missing"])
|