Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 13 additions & 11 deletions src/schemaforge/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,15 +136,17 @@ def check(directory: str, canonical: str, type_map_path: str | None) -> None:

def _detect_format(path: str) -> str:
ext = Path(path).suffix.lower()
if ext == ".sql":
return "sql"
if ext == ".prisma":
return "prisma"
if ext in (".ts", ".tsx"):
return "drizzle"
if ext == ".py":
return "django"
if ext in (".json",):
return "typeorm"
return "sql" # default
ext_map = {
".sql": "sql",
".prisma": "prisma",
".ts": "drizzle",
".tsx": "drizzle",
".py": "django",
".json": "json_schema",
".graphql": "graphql",
".gql": "graphql",
".cs": "ef",
".scala": "scala",
}
return ext_map.get(ext, "sql")

2 changes: 1 addition & 1 deletion src/schemaforge/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def create_server() -> Any:
@server.tool(
name="convert",
description="Convert a schema from one format to another. "
"All 9 formats support conversion to and from every other format. "
"All 11 formats support conversion to and from every other format. "
"Returns the converted schema as text.",
)
def convert_tool(
Expand Down
46 changes: 45 additions & 1 deletion tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from schemaforge.cli import main
from schemaforge.cli import _detect_format, main

# ── Helpers ──

Expand Down Expand Up @@ -459,3 +459,47 @@ def test_check_help(self):
assert result.exit_code == 0
assert "Usage:" in result.output
assert "--dir" in result.output


# ═══════════════════════════════════════════════════════════════
# _detect_format
# ═══════════════════════════════════════════════════════════════

class TestDetectFormat:
"""Tests for the private _detect_format helper."""

def test_sql_extension(self):
assert _detect_format("schema.sql") == "sql"

def test_prisma_extension(self):
assert _detect_format("schema.prisma") == "prisma"

def test_drizzle_ts(self):
assert _detect_format("schema.ts") == "drizzle"

def test_drizzle_tsx(self):
assert _detect_format("schema.tsx") == "drizzle"

def test_django_python(self):
assert _detect_format("models.py") == "django"

def test_json_schema(self):
assert _detect_format("schema.json") == "json_schema"

def test_graphql(self):
assert _detect_format("schema.graphql") == "graphql"

def test_graphql_gql(self):
assert _detect_format("schema.gql") == "graphql"

def test_ef_csharp(self):
assert _detect_format("entities.cs") == "ef"

def test_scala(self):
assert _detect_format("models.scala") == "scala"

def test_unknown_extension_defaults_to_sql(self):
assert _detect_format("schema.txt") == "sql"

def test_no_extension_defaults_to_sql(self):
assert _detect_format("schema") == "sql"
Loading