Implemented basic device controls

This commit is contained in:
Jurn Wubben 2025-08-26 14:44:27 +02:00
parent aca30088b0
commit aefa4275d4
2 changed files with 214 additions and 132 deletions

233
app.py
View file

@ -1,34 +1,20 @@
from __future__ import annotations
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from nt import device_encoding
from typing import Annotated from typing import Annotated
from fastapi import Depends, FastAPI, HTTPException from fastapi import Depends, FastAPI, HTTPException
from sqlmodel import SQLModel, Session, col, create_engine, select from sqlmodel import SQLModel, Session, create_engine, delete, select
from netbrite import NetBrite, NetbriteConnectionException, Zone
import netbrite as nb
from db import ( from db import (
APIAddNetBriteDevice, MessageDB,
APIGetNetBriteDevice, NetBriteBase,
APIGetZone, NetBriteDB,
APIMessages, NetBritePublic,
BaseNetBriteDevice, ZoneDB,
BaseZone,
) )
sqlite_file_name = "devices.db" DB_URL = "sqlite:///devices.db"
sqlite_url = f"sqlite:///{sqlite_file_name}" engine = create_engine(DB_URL, connect_args={"check_same_thread": False})
connect_args = {"check_same_thread": False}
engine = create_engine(sqlite_url, connect_args=connect_args)
@asynccontextmanager
async def lifespan(app: FastAPI):
print("creating tables")
create_db_and_tables()
yield
def create_db_and_tables():
SQLModel.metadata.create_all(engine)
def get_session(): def get_session():
@ -38,78 +24,161 @@ def get_session():
SessionDep = Annotated[Session, Depends(get_session)] SessionDep = Annotated[Session, Depends(get_session)]
@asynccontextmanager
async def lifespan(_: FastAPI):
SQLModel.metadata.create_all(engine)
load_devices_from_db()
yield
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
devices: list[NetBrite] = []
active_devices: dict[int, nb.NetBrite] = {}
device_status: dict[int, bool] = {}
@app.post("/api/device") # ---------- routes ----------
def create_device(device: APIAddNetBriteDevice, session: SessionDep): @app.post("/api/device", response_model=NetBritePublic)
statement = select(BaseNetBriteDevice).where( def create_device(device: NetBriteBase, session: SessionDep):
col(BaseNetBriteDevice.address) == device.address if session.exec(
) select(NetBriteDB).where(NetBriteDB.address == device.address)
result = session.exec(statement) ).first():
if result.first() != None: raise HTTPException(400, "Device already exists")
raise HTTPException(400, "Device is already added")
db_device = NetBriteDB.model_validate(device)
session.add(db_device)
session.commit()
session.refresh(db_device)
load_device(db_device, session)
return db_device
@app.get("/api/devices", response_model=list[NetBritePublic])
def get_devices(session: SessionDep):
return session.exec(select(NetBriteDB)).all()
@app.post("/api/devices/{device_id}", response_model=NetBritePublic)
def update_device(device_id: int, session: SessionDep):
db_dev = session.get(NetBriteDB, device_id)
if not db_dev:
raise HTTPException(404, "Device not found")
return db_dev
# TODO: implement me
@app.post("/api/devices/{device_id}/reconnect")
def connect_device(device_id: int, session: SessionDep) -> dict[str, bool]:
db_dev = session.get(NetBriteDB, device_id)
if not db_dev:
raise HTTPException(404, "Device not found")
try: try:
netbrite_device = NetBrite(device.address, device.port) active_devices[device_id] = nb.NetBrite(db_dev.address, db_dev.port)
devices.append(netbrite_device) device_status[device_id] = True
except NetbriteConnectionException as exc: load_zones(session, device_id, active_devices[device_id])
raise HTTPException(400, "Failed to connect to device") return {"connected": True}
except nb.NetbriteConnectionException as exc:
device_status[device_id] = False
raise HTTPException(400, str(exc))
@app.delete("/api/devices/{device_id}")
def delete_device(device_id: int, session: SessionDep):
db_dev = session.get(NetBriteDB, device_id)
if not db_dev:
raise HTTPException(404, "Device not found")
delete: list[MessageDB | ZoneDB | NetBriteDB] = [db_dev]
for zone in db_dev.zones:
if zone.default_message != None:
delete.append(zone.default_message)
delete.append(zone)
if device_id in active_devices:
active_devices[device_id].sock.close()
del active_devices[device_id]
print(delete)
for i in delete:
session.delete(i)
dbdevice = BaseNetBriteDevice(address=device.address, port=device.port)
session.add(dbdevice)
session.commit() session.commit()
session.refresh(dbdevice) return 200
return device
@app.get("/api/devices") # ---------- helper ----------
def get_devices(session: SessionDep): def load_devices_from_db() -> None:
db_devices: dict[int, BaseNetBriteDevice] = {} with Session(engine) as session:
for device in session.exec(select(NetBriteDB)).all():
statement = select(BaseZone, BaseNetBriteDevice).join(BaseNetBriteDevice) load_device(device, session)
results = session.exec(statement)
for zone, device in results:
if device.id == None:
continue
if not device.id in db_devices:
connected = devices.inde
get_device = APIGetNetBriteDevice(device., zones={})
db_devices[device.id] = device
db_devices[device.id].zones_list
for device in devices: def load_device(device: NetBriteDB, session: SessionDep):
id = device.id or 0
try:
active_devices[id] = nb.NetBrite(device.address, device.port)
device_status[id] = True
load_zones(session, id, active_devices[id])
except nb.NetbriteConnectionException as exc:
device_status[id] = False
print(f"Could not connect to {device.address}:{device.port}{exc}")
zones: dict[str, APIGetZone] = {}
for index, zone in device.zones_list.items():
x, y, xend, yend = zone.rect
width = xend - x
height = yend - y
zones[index] = APIGetZone( def load_zones(session: Session, device_id: int, net_dev: nb.NetBrite) -> None:
x=x, zones: dict[str, nb.Zone] = {}
y=y, statement = select(ZoneDB).where(ZoneDB.netbrite_id == device_id)
width=width,
height=height, for zone in session.exec(statement).all():
scroll_speed=zone.scroll_speed, msg = zone.default_message
pause_duration=zone.pause_duration,
volume=zone.volume, default_msg = (
default_font=zone.default_font, nb.Message(
default_color=zone.default_color, text=msg.text,
default_message=zone.initial_text,
) )
if msg
else nb.Message(f"Zone {zone.name}")
)
return devices zones[zone.name] = nb.Zone(
x=zone.x,
y=zone.y,
width=zone.width,
height=zone.height,
scroll_speed=zone.scroll_speed,
pause_duration=zone.pause_duration,
volume=zone.volume,
default_font=zone.default_font,
default_color=zone.default_color,
initial_text=default_msg,
)
if zones:
net_dev.zones(zones)
@app.get("/api/devices/{device_index}") def create_default_zone(session: Session, device_id: int) -> None:
def get_device(device_index: int): zone = ZoneDB(
if device_index > len(devices) - 1: name="0",
raise HTTPException(400, "Device doesn't exist") x=0,
y=0,
width=150,
height=7,
netbrite_id=device_id,
)
msg = MessageDB(
text="{erase}Welcome",
ttl=60,
)
session.add(msg)
session.add(zone)
return devices[device_index] session.commit()
session.refresh(zone)
session.refresh(msg)
zone.default_message_id = msg.id
session.commit()

113
db.py
View file

@ -1,65 +1,78 @@
from sqlmodel import Field, Session, SQLModel, create_engine, select # from __future__ import annotations
from sqlmodel import Field, Relationship, SQLModel
from netbrite import Colors, Fonts, Message, Priorities, ScrollSpeeds from netbrite import Colors, Fonts, Priorities, ScrollSpeeds, Message
class BaseZone(SQLModel, table=True, echo=True): # --- Message ---
class MessageBase(SQLModel):
activation_delay: int = 0
display_delay: int = 0
display_repeat: int = 0
priority: Priorities = Priorities.OVERRIDE
text: str = ""
ttl: int = 0
class MessageDB(MessageBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
zone: "ZoneDB" = Relationship( # pyright: ignore[reportAny]
back_populates="default_message"
)
class MessagePublic(MessageBase):
id: int
zone: "ZoneDB"
# --- Device ---
class NetBriteBase(SQLModel):
address: str = Field(unique=True, index=True)
port: int = 700
class NetBriteDB(NetBriteBase, table=True):
id: int | None = Field(default=None, primary_key=True)
zones: list["ZoneDB"] = Relationship( # pyright: ignore[reportAny]
back_populates="netbrite"
)
class NetBritePublic(NetBriteBase):
id: int
zones: list["ZoneDB"]
active: bool
# --- Zone ---
class ZoneBase(SQLModel):
name: str
x: int x: int
y: int y: int
width: int width: int
height: int height: int
scroll_speed: ScrollSpeeds scroll_speed: ScrollSpeeds = ScrollSpeeds.NORMAL
pause_duration: int pause_duration: int = 1000
volume: int volume: int = 4
default_font: Fonts default_font: Fonts = Fonts.NORMAL_7
default_color: Colors default_color: Colors = Colors.RED
default_message: int | None = Field(foreign_key="apimessages.id") default_message_id: int | None = Field(default=None, foreign_key="messagedb.id")
netbrite_device_id: int = Field(foreign_key="apinetbritedevice.id") netbrite_id: int = Field(default=None, foreign_key="netbritedb.id")
class APIGetZone(SQLModel, table=True, echo=True): class ZoneDB(ZoneBase, table=True):
id: int | None = Field(default=None, primary_key=True) id: int | None = Field(default=None, primary_key=True)
x: int default_message: MessageDB | None = Relationship( # pyright: ignore[reportAny]
y: int back_populates="zone"
width: int )
height: int netbrite: NetBriteDB = Relationship( # pyright: ignore[reportAny]
scroll_speed: ScrollSpeeds back_populates="zones"
pause_duration: int )
volume: int
default_font: Fonts
default_color: Colors
default_message: Message | None
class APIMessages(SQLModel, table=True, echo=True): class ZonePublic(ZoneBase):
id: int | None = Field(default=None, primary_key=True) id: int
default_message: MessagePublic
activation_delay: int netbrite: NetBritePublic
display_delay: int
display_repeat: int
priority: Priorities
text: str
ttl: int
class BaseNetBriteDevice(SQLModel, table=True, echo=True):
id: int | None = Field(default=None, primary_key=True)
address: str = Field(unique=True)
port: int = 700
class APIAddNetBriteDevice(SQLModel, table=False):
address: str
port: int = 700
class APIGetNetBriteDevice(SQLModel, table=False):
address: str
port: int = 700
connected: bool
zones: dict[str, APIGetZone]