signcontrol/app.py

184 lines
4.9 KiB
Python

from __future__ import annotations
from contextlib import asynccontextmanager
from typing import Annotated
from fastapi import Depends, FastAPI, HTTPException
from sqlmodel import SQLModel, Session, create_engine, delete, select
import netbrite as nb
from db import (
MessageDB,
NetBriteBase,
NetBriteDB,
NetBritePublic,
ZoneDB,
)
# SQLite database URL used by the module-wide engine below.
DB_URL = "sqlite:///devices.db"
# check_same_thread=False is handed to the sqlite3 driver so a connection may
# be used from a thread other than the one that created it.
# NOTE(review): presumably needed because request handlers can run on
# different threads — confirm.
engine = create_engine(DB_URL, connect_args={"check_same_thread": False})
def get_session():
    """Yield a database session bound to the module-level engine.

    The session is opened when the generator starts and is closed by the
    ``with`` block once the consumer has finished with it.
    """
    with Session(engine) as db_session:
        yield db_session
# Dependency alias: a route parameter annotated with SessionDep is declared as
# a Session that FastAPI obtains from get_session().
SessionDep = Annotated[Session, Depends(get_session)]
@asynccontextmanager
async def lifespan(_: FastAPI):
    """Run the application's startup work.

    Creates the tables for the registered SQLModel models on the module
    engine, then tries to connect to every stored device. No work is done
    after the ``yield`` (i.e. on shutdown).
    """
    metadata = SQLModel.metadata
    metadata.create_all(engine)
    load_devices_from_db()
    yield
# Application instance; `lifespan` performs the startup work.
app = FastAPI(lifespan=lifespan)
# Sign connections created by this module, keyed by the device's database id.
active_devices: dict[int, nb.NetBrite] = {}
# Outcome of the most recent connection attempt per device id
# (True = connected, False = the attempt raised a connection error).
device_status: dict[int, bool] = {}
# ---------- routes ----------
@app.post("/api/device", response_model=NetBritePublic)
def create_device(device: NetBriteBase, session: SessionDep):
    """Store a new sign and try to connect to it.

    Responds with HTTP 400 when a device with the same address is already
    stored. Otherwise the device is persisted, a connection attempt is made
    through ``load_device``, and the stored row is returned.
    """
    same_address = select(NetBriteDB).where(NetBriteDB.address == device.address)
    existing = session.exec(same_address).first()
    if existing:
        raise HTTPException(400, "Device already exists")

    new_row = NetBriteDB.model_validate(device)
    session.add(new_row)
    session.commit()
    # Reload the row from the database after the commit.
    session.refresh(new_row)

    load_device(new_row, session)
    return new_row
@app.get("/api/devices", response_model=list[NetBritePublic])
def get_devices(session: SessionDep):
    """Return every stored device."""
    all_devices = select(NetBriteDB)
    return session.exec(all_devices).all()
@app.post("/api/devices/{device_id}", response_model=NetBritePublic)
def update_device(device_id: int, session: SessionDep):
    """Return the stored device; the update itself is not implemented yet.

    Responds with HTTP 404 when no device has the given id.
    """
    # TODO: implement me
    stored = session.get(NetBriteDB, device_id)
    if stored:
        return stored
    raise HTTPException(404, "Device not found")
@app.post("/api/devices/{device_id}/reconnect")
def connect_device(device_id: int, session: SessionDep) -> dict[str, bool]:
    """Open a fresh connection to a stored sign and push its zones.

    Any connection already registered for the device is closed first, so a
    reconnect does not leak the previous socket.

    Returns:
        ``{"connected": True}`` when the connection and zone upload succeed.

    Raises:
        HTTPException: 404 when no device has the given id; 400 (carrying the
            connection error text) when the sign cannot be reached.
    """
    db_dev = session.get(NetBriteDB, device_id)
    if db_dev is None:
        raise HTTPException(404, "Device not found")

    # Drop and close the old connection before replacing it; overwriting the
    # dict entry alone would leave its socket open.
    previous = active_devices.pop(device_id, None)
    if previous is not None:
        previous.sock.close()

    try:
        net_dev = nb.NetBrite(db_dev.address, db_dev.port)
        active_devices[device_id] = net_dev
        device_status[device_id] = True
        load_zones(session, device_id, net_dev)
    except nb.NetbriteConnectionException as exc:
        device_status[device_id] = False
        # Chain the cause so the original traceback is preserved in logs.
        raise HTTPException(400, str(exc)) from exc
    return {"connected": True}
@app.delete("/api/devices/{device_id}")
def delete_device(device_id: int, session: SessionDep):
    """Delete a device together with its zones and their default messages.

    The database rows are removed first; only after the commit succeeds are
    the in-memory connection and status entries for the device discarded.

    Returns:
        The integer ``200`` as the response body.

    Raises:
        HTTPException: 404 when no device has the given id.
    """
    db_dev = session.get(NetBriteDB, device_id)
    if db_dev is None:
        raise HTTPException(404, "Device not found")

    # Not named `delete`: that would shadow the `delete` imported from sqlmodel.
    doomed: list[MessageDB | ZoneDB | NetBriteDB] = [db_dev]
    for zone in db_dev.zones:
        if zone.default_message is not None:
            doomed.append(zone.default_message)
        doomed.append(zone)

    for row in doomed:
        session.delete(row)
    session.commit()

    # Tear down runtime state after the commit, so a failed commit does not
    # leave a still-stored device without its connection.
    net_dev = active_devices.pop(device_id, None)
    if net_dev is not None:
        net_dev.sock.close()
    # Forget the status too; otherwise a stale entry outlives the device.
    device_status.pop(device_id, None)
    return 200
# ---------- helpers ----------
def load_devices_from_db() -> None:
    """Try to connect to every device stored in the database."""
    with Session(engine) as session:
        stored_devices = session.exec(select(NetBriteDB)).all()
        for stored in stored_devices:
            load_device(stored, session)
def load_device(device: NetBriteDB, session: Session) -> None:
    """Connect to a sign, record the outcome, and push its stored zones.

    This is best effort: a connection failure is recorded in
    ``device_status`` and reported on stdout, but never raised to the caller.

    Args:
        device: Stored device whose address and port are used to connect.
        session: Open database session used to read the device's zones.
    """
    # A row without an id is tracked under key 0.
    device_id = device.id or 0
    try:
        net_dev = nb.NetBrite(device.address, device.port)
        active_devices[device_id] = net_dev
        device_status[device_id] = True
        load_zones(session, device_id, net_dev)
    except nb.NetbriteConnectionException as exc:
        device_status[device_id] = False
        # Separator added: the port and the error text used to run together.
        print(f"Could not connect to {device.address}:{device.port}: {exc}")
def load_zones(session: Session, device_id: int, net_dev: nb.NetBrite) -> None:
    """Send the zones stored for ``device_id`` to the given sign.

    Each stored zone becomes an ``nb.Zone`` keyed by its name. Its initial
    text is the zone's default message when one exists, otherwise a
    ``"Zone <name>"`` placeholder. Nothing is sent when the device has no
    stored zones.
    """
    zone_rows = session.exec(
        select(ZoneDB).where(ZoneDB.netbrite_id == device_id)
    ).all()

    layout: dict[str, nb.Zone] = {}
    for row in zone_rows:
        stored_msg = row.default_message
        if stored_msg:
            initial = nb.Message(text=stored_msg.text)
        else:
            initial = nb.Message(f"Zone {row.name}")

        layout[row.name] = nb.Zone(
            x=row.x,
            y=row.y,
            width=row.width,
            height=row.height,
            scroll_speed=row.scroll_speed,
            pause_duration=row.pause_duration,
            volume=row.volume,
            default_font=row.default_font,
            default_color=row.default_color,
            initial_text=initial,
        )

    if not layout:
        return
    net_dev.zones(layout)
def create_default_zone(session: Session, device_id: int) -> None:
    """Create zone "0" for a device with a "{erase}Welcome" default message.

    Both rows are written in a single transaction, so a failure cannot leave
    a zone without its default message or an unreferenced message row.

    Args:
        session: Open database session; it is committed by this function.
        device_id: Id of the device the new zone belongs to.
    """
    msg = MessageDB(
        text="{erase}Welcome",
        ttl=60,
    )
    session.add(msg)
    # Flush without committing so the database assigns msg.id, which the
    # zone needs for its default-message link.
    session.flush()

    zone = ZoneDB(
        name="0",
        x=0,
        y=0,
        width=150,
        height=7,
        netbrite_id=device_id,
    )
    zone.default_message_id = msg.id
    session.add(zone)
    session.commit()