From aefa4275d441c8bfa81bd8034c060e8fec857864 Mon Sep 17 00:00:00 2001 From: Jurn Wubben Date: Tue, 26 Aug 2025 14:44:27 +0200 Subject: [PATCH] Implemented basic device controls --- app.py | 233 +++++++++++++++++++++++++++++++++++++-------------------- db.py | 113 +++++++++++++++------------- 2 files changed, 214 insertions(+), 132 deletions(-) diff --git a/app.py b/app.py index 2ed7fe0..f74fd08 100644 --- a/app.py +++ b/app.py @@ -1,34 +1,20 @@ +from __future__ import annotations from contextlib import asynccontextmanager -from nt import device_encoding from typing import Annotated from fastapi import Depends, FastAPI, HTTPException -from sqlmodel import SQLModel, Session, col, create_engine, select -from netbrite import NetBrite, NetbriteConnectionException, Zone +from sqlmodel import SQLModel, Session, create_engine, delete, select + +import netbrite as nb from db import ( - APIAddNetBriteDevice, - APIGetNetBriteDevice, - APIGetZone, - APIMessages, - BaseNetBriteDevice, - BaseZone, + MessageDB, + NetBriteBase, + NetBriteDB, + NetBritePublic, + ZoneDB, ) -sqlite_file_name = "devices.db" -sqlite_url = f"sqlite:///{sqlite_file_name}" - -connect_args = {"check_same_thread": False} -engine = create_engine(sqlite_url, connect_args=connect_args) - - -@asynccontextmanager -async def lifespan(app: FastAPI): - print("creating tables") - create_db_and_tables() - yield - - -def create_db_and_tables(): - SQLModel.metadata.create_all(engine) +DB_URL = "sqlite:///devices.db" +engine = create_engine(DB_URL, connect_args={"check_same_thread": False}) def get_session(): @@ -38,78 +24,161 @@ def get_session(): SessionDep = Annotated[Session, Depends(get_session)] + +@asynccontextmanager +async def lifespan(_: FastAPI): + SQLModel.metadata.create_all(engine) + load_devices_from_db() + yield + + app = FastAPI(lifespan=lifespan) -devices: list[NetBrite] = [] + +active_devices: dict[int, nb.NetBrite] = {} +device_status: dict[int, bool] = {} -@app.post("/api/device") -def create_device(device: APIAddNetBriteDevice, session: SessionDep): - statement = select(BaseNetBriteDevice).where( - col(BaseNetBriteDevice.address) == device.address - ) - result = session.exec(statement) - if result.first() != None: - raise HTTPException(400, "Device is already added") +# ---------- routes ---------- +@app.post("/api/device", response_model=NetBritePublic) +def create_device(device: NetBriteBase, session: SessionDep): + if session.exec( + select(NetBriteDB).where(NetBriteDB.address == device.address) + ).first(): + raise HTTPException(400, "Device already exists") + + db_device = NetBriteDB.model_validate(device) + session.add(db_device) + session.commit() + session.refresh(db_device) + + load_device(db_device, session) + return db_device + + +@app.get("/api/devices", response_model=list[NetBritePublic]) +def get_devices(session: SessionDep): + return session.exec(select(NetBriteDB)).all() + + +@app.post("/api/devices/{device_id}", response_model=NetBritePublic) +def update_device(device_id: int, session: SessionDep): + db_dev = session.get(NetBriteDB, device_id) + if not db_dev: + raise HTTPException(404, "Device not found") + + return db_dev + + +# TODO: implement me +@app.post("/api/devices/{device_id}/reconnect") +def connect_device(device_id: int, session: SessionDep) -> dict[str, bool]: + db_dev = session.get(NetBriteDB, device_id) + if not db_dev: + raise HTTPException(404, "Device not found") try: - netbrite_device = NetBrite(device.address, device.port) - devices.append(netbrite_device) - except NetbriteConnectionException as exc: - raise HTTPException(400, "Failed to connect to device") + active_devices[device_id] = nb.NetBrite(db_dev.address, db_dev.port) + device_status[device_id] = True + load_zones(session, device_id, active_devices[device_id]) + return {"connected": True} + except nb.NetbriteConnectionException as exc: + device_status[device_id] = False + raise HTTPException(400, str(exc)) + + +@app.delete("/api/devices/{device_id}") +def delete_device(device_id: int, session: SessionDep): + db_dev = session.get(NetBriteDB, device_id) + if not db_dev: + raise HTTPException(404, "Device not found") + + delete: list[MessageDB | ZoneDB | NetBriteDB] = [db_dev] + for zone in db_dev.zones: + if zone.default_message != None: + delete.append(zone.default_message) + delete.append(zone) + + if device_id in active_devices: + active_devices[device_id].sock.close() + del active_devices[device_id] + + print(delete) + for i in delete: + session.delete(i) - dbdevice = BaseNetBriteDevice(address=device.address, port=device.port) - session.add(dbdevice) session.commit() - session.refresh(dbdevice) - - return device + return 200 -@app.get("/api/devices") -def get_devices(session: SessionDep): - db_devices: dict[int, BaseNetBriteDevice] = {} - - statement = select(BaseZone, BaseNetBriteDevice).join(BaseNetBriteDevice) - results = session.exec(statement) - - for zone, device in results: - if device.id == None: - continue - if not device.id in db_devices: - connected = devices.inde - get_device = APIGetNetBriteDevice(device., zones={}) - db_devices[device.id] = device - - db_devices[device.id].zones_list +# ---------- helper ---------- +def load_devices_from_db() -> None: + with Session(engine) as session: + for device in session.exec(select(NetBriteDB)).all(): + load_device(device, session) - for device in devices: +def load_device(device: NetBriteDB, session: SessionDep): + id = device.id or 0 + try: + active_devices[id] = nb.NetBrite(device.address, device.port) + device_status[id] = True + load_zones(session, id, active_devices[id]) + except nb.NetbriteConnectionException as exc: + device_status[id] = False + print(f"Could not connect to {device.address}:{device.port} — {exc}") - zones: dict[str, APIGetZone] = {} - for index, zone in device.zones_list.items(): - x, y, xend, yend = zone.rect - width = xend - x - height = yend - y - zones[index] = APIGetZone( - x=x, - y=y, - width=width, - height=height, - scroll_speed=zone.scroll_speed, - pause_duration=zone.pause_duration, - volume=zone.volume, - default_font=zone.default_font, - default_color=zone.default_color, - default_message=zone.initial_text, +def load_zones(session: Session, device_id: int, net_dev: nb.NetBrite) -> None: + zones: dict[str, nb.Zone] = {} + statement = select(ZoneDB).where(ZoneDB.netbrite_id == device_id) + + for zone in session.exec(statement).all(): + msg = zone.default_message + + default_msg = ( + nb.Message( + text=msg.text, ) + if msg + else nb.Message(f"Zone {zone.name}") + ) - return devices + zones[zone.name] = nb.Zone( + x=zone.x, + y=zone.y, + width=zone.width, + height=zone.height, + scroll_speed=zone.scroll_speed, + pause_duration=zone.pause_duration, + volume=zone.volume, + default_font=zone.default_font, + default_color=zone.default_color, + initial_text=default_msg, + ) + + if zones: + net_dev.zones(zones) -@app.get("/api/devices/{device_index}") -def get_device(device_index: int): - if device_index > len(devices) - 1: - raise HTTPException(400, "Device doesn't exist") +def create_default_zone(session: Session, device_id: int) -> None: + zone = ZoneDB( + name="0", + x=0, + y=0, + width=150, + height=7, + netbrite_id=device_id, + ) + msg = MessageDB( + text="{erase}Welcome", + ttl=60, + ) + session.add(msg) + session.add(zone) - return devices[device_index] + session.commit() + session.refresh(zone) + session.refresh(msg) + + zone.default_message_id = msg.id + session.commit() diff --git a/db.py b/db.py index f4524d4..0370419 100644 --- a/db.py +++ b/db.py @@ -1,65 +1,78 @@ -from sqlmodel import Field, Session, SQLModel, create_engine, select - -from netbrite import Colors, Fonts, Message, Priorities, ScrollSpeeds +# from __future__ import annotations +from sqlmodel import Field, Relationship, SQLModel +from netbrite import Colors, Fonts, Priorities, ScrollSpeeds, Message -class BaseZone(SQLModel, table=True, echo=True): +# --- Message --- +class MessageBase(SQLModel): + activation_delay: int = 0 + display_delay: int = 0 + display_repeat: int = 0 + priority: Priorities = Priorities.OVERRIDE + text: str = "" + ttl: int = 0 + + +class MessageDB(MessageBase, table=True): id: int | None = Field(default=None, primary_key=True) + zone: "ZoneDB" = Relationship( # pyright: ignore[reportAny] + back_populates="default_message" + ) + +class MessagePublic(MessageBase): + id: int + zone: "ZoneDB" + + +# --- Device --- +class NetBriteBase(SQLModel): + address: str = Field(unique=True, index=True) + port: int = 700 + + +class NetBriteDB(NetBriteBase, table=True): + id: int | None = Field(default=None, primary_key=True) + zones: list["ZoneDB"] = Relationship( # pyright: ignore[reportAny] + back_populates="netbrite" + ) + + +class NetBritePublic(NetBriteBase): + id: int + zones: list["ZoneDB"] + active: bool + + +# --- Zone --- +class ZoneBase(SQLModel): + name: str x: int y: int width: int height: int - scroll_speed: ScrollSpeeds - pause_duration: int - volume: int - default_font: Fonts - default_color: Colors + scroll_speed: ScrollSpeeds = ScrollSpeeds.NORMAL + pause_duration: int = 1000 + volume: int = 4 + default_font: Fonts = Fonts.NORMAL_7 + default_color: Colors = Colors.RED - default_message: int | None = Field(foreign_key="apimessages.id") - netbrite_device_id: int = Field(foreign_key="apinetbritedevice.id") + default_message_id: int | None = Field(default=None, foreign_key="messagedb.id") + netbrite_id: int = Field(default=None, foreign_key="netbritedb.id") -class APIGetZone(SQLModel, table=True, echo=True): +class ZoneDB(ZoneBase, table=True): id: int | None = Field(default=None, primary_key=True) - x: int - y: int - width: int - height: int - scroll_speed: ScrollSpeeds - pause_duration: int - volume: int - default_font: Fonts - default_color: Colors - - default_message: Message | None + default_message: MessageDB | None = Relationship( # pyright: ignore[reportAny] + back_populates="zone" + ) + netbrite: NetBriteDB = Relationship( # pyright: ignore[reportAny] + back_populates="zones" + ) -class APIMessages(SQLModel, table=True, echo=True): - id: int | None = Field(default=None, primary_key=True) - - activation_delay: int - display_delay: int - display_repeat: int - priority: Priorities - text: str - ttl: int - - -class BaseNetBriteDevice(SQLModel, table=True, echo=True): - id: int | None = Field(default=None, primary_key=True) - address: str = Field(unique=True) - port: int = 700 - - -class APIAddNetBriteDevice(SQLModel, table=False): - address: str - port: int = 700 - - -class APIGetNetBriteDevice(SQLModel, table=False): - address: str - port: int = 700 - connected: bool - zones: dict[str, APIGetZone] +class ZonePublic(ZoneBase): + id: int + default_message: MessagePublic + netbrite: NetBritePublic