# Source code for neighborly.data_collection

"""Data collection.

This module contains functionality for collecting and exporting data from a simulation.

Its structure is informed by the data collection layer of Mesa, an agent-based modeling
library written in Python. Here we adapt their functionality to fit the ECS architecture
of the simulation.

"""
from __future__ import annotations

from typing import Any, Iterator, Optional, Sequence

import polars as pl

from neighborly.ecs import SystemGroup


class DataTablesIterator:
    """Walk a fixed sequence of table names, producing each table as a data frame.

    The data frame for a name is requested from ``tables`` only when that step
    of the iteration is reached. Once every name has been produced, further
    calls to ``__next__`` keep raising ``StopIteration``.
    """

    __slots__ = ("table_names", "tables", "idx")

    # Snapshot of the names to produce, in order.
    table_names: tuple[str, ...]
    # The DataTables resource that data frames are requested from.
    tables: DataTables
    # Position in ``table_names`` of the next name to produce.
    idx: int

    def __init__(self, table_names: Sequence[str], tables: DataTables) -> None:
        # Freeze the names into a tuple so later changes to the caller's
        # sequence do not affect this iteration.
        self.table_names, self.tables, self.idx = tuple(table_names), tables, 0

    def __iter__(self) -> Iterator[tuple[str, pl.DataFrame]]:
        return self

    def __next__(self) -> tuple[str, pl.DataFrame]:
        if self.idx >= len(self.table_names):
            raise StopIteration

        current = self.table_names[self.idx]
        frame = self.tables.get_data_frame(current)
        # Advance only after the frame was built, so a failing lookup can be retried.
        self.idx += 1
        return current, frame
class DataTables:
    """A shared resource that collects data from the simulation into tables.

    Each table is stored column-wise: a dict mapping column names to lists of
    entries. Every call to ``add_data_row`` appends exactly one value to each
    column of the target table, so all columns of a table stay the same length.
    """

    __slots__ = ("_tables",)

    # Table names mapped to dicts with column names mapped to data entries.
    _tables: dict[str, dict[str, list[Any]]]

    def __init__(
        self,
        tables: Optional[dict[str, tuple[str, ...]]] = None,
    ) -> None:
        """
        Parameters
        ----------
        tables
            Optional mapping of table names to tuples of column names. An empty
            table is created for every entry.
        """
        self._tables = {}

        # Construct all the tables
        if tables:
            for table_name, column_names in tables.items():
                self.create_table(table_name, column_names)

    def create_table(self, table_name: str, column_names: tuple[str, ...]) -> None:
        """Create a new table for data collection.

        If a table with the same name already exists, it is replaced by a new,
        empty table.

        Parameters
        ----------
        table_name
            The name of the new table.
        column_names
            The names of columns within the table.
        """
        self._tables[table_name] = {column: [] for column in column_names}

    def add_data_row(self, table_name: str, row_data: dict[str, Any]) -> None:
        """Add a new row of data to a table.

        The row is added atomically: if any column is missing from
        ``row_data``, nothing is appended. Keys in ``row_data`` that are not
        columns of the table are ignored.

        Parameters
        ----------
        table_name
            The table to add the row to.
        row_data
            A row of data to add to the table where each dict key is the name
            of the column.

        Raises
        ------
        ValueError
            If no table named ``table_name`` exists.
        KeyError
            If ``row_data`` is missing one of the table's columns. The message
            names the first missing column in column order.
        """
        table = self._tables.get(table_name)
        if table is None:
            raise ValueError(f"Could not find table with name: {table_name}")

        # Validate before mutating. Appending column-by-column and failing
        # midway would leave the table's columns with different lengths.
        for column in table:
            if column not in row_data:
                raise KeyError(f"Row data is missing column: {column}")

        for column, values in table.items():
            values.append(row_data[column])

    def get_data_frame(self, table_name: str) -> pl.DataFrame:
        """Create a Polars data frame from a table.

        Parameters
        ----------
        table_name
            The name of the table to retrieve.

        Returns
        -------
        pl.DataFrame
            A polars DataFrame with one column per table column.

        Raises
        ------
        KeyError
            If no table named ``table_name`` exists.
        """
        return pl.DataFrame(self._tables[table_name])

    def __iter__(self) -> Iterator[tuple[str, pl.DataFrame]]:
        """Iterate over ``(table_name, data_frame)`` pairs.

        The table names are snapshotted when iteration starts.
        """
        return DataTablesIterator(list(self._tables.keys()), self)

    def to_dict(self) -> dict[str, Any]:
        """Serialize the object to a JSON-serializable dict.

        The returned column lists are copies. Mutating the result does not
        affect the collected data.
        """
        return {
            table_name: {column: list(values) for column, values in columns.items()}
            for table_name, columns in self._tables.items()
        }
class DataCollectionSystems(SystemGroup):
    """Groups together the systems that gather simulation data.

    Any system that records data while the simulation runs should be placed
    in this group.
    """