# Source code for ostatslib.states.features_set

"""
FeaturesSet abstract class module
"""

from abc import ABC
from dataclasses import MISSING, Field, fields

import numpy as np
from numpy.typing import NDArray

KnownFeaturesList = list[tuple[str, float | int | str]]


class FeaturesSet(ABC):
    """
    Abstract base class for FeaturesSets

    Every method iterates ``dataclasses.fields(self)``, so concrete subclasses
    must be dataclasses. Two keys of each field's ``metadata`` are read:

    * ``'gym_space'``: the Gymnasium space of the feature (required by
      :meth:`as_gymnasium_space`).
    * ``'get_value_fn'``: optional callable mapping the raw field value to
      its feature value. When the key is absent or ``None``, the raw value
      is used as-is.
    """

    def list_known_features(self) -> 'KnownFeaturesList':
        """
        Lists fields that have values different from default (unknown state attribute)

        A field's default is its declared ``default``. For fields declared with
        ``default_factory``, it is a freshly built factory value. Fields with no
        default at all are always reported.

        Returns:
            KnownFeaturesList: list of (field name, value) pairs for non-default values
        """
        known_features = []
        for feature_field in fields(self):
            value = getattr(self, feature_field.name)
            if feature_field.default is not MISSING:
                default = feature_field.default
            elif feature_field.default_factory is not MISSING:
                # Without this branch, default_factory fields would be compared
                # against the MISSING sentinel and always reported as known.
                default = feature_field.default_factory()
            else:
                # No declared default: any value counts as known.
                known_features.append((feature_field.name, value))
                continue
            if default != value:
                known_features.append((feature_field.name, value))
        return known_features

    def as_gymnasium_space(self) -> dict:
        """
        Features as Gymnasium space

        Returns:
            dict: dictionary mapping each field name to ``metadata['gym_space']``

        Raises:
            KeyError: if a field's metadata has no ``'gym_space'`` entry
        """
        return {f.name: f.metadata['gym_space'] for f in fields(self)}

    def as_features_dict(self) -> dict:
        """
        Features values as dictionary

        Returns:
            dict: dictionary mapping each field name to its feature value array
            (at least one-dimensional)
        """
        return {f.name: self.__get_feature_value(f) for f in fields(self)}

    def __array__(self, dtype=None, copy=None) -> np.ndarray:
        """
        NumPy array protocol: concatenates all feature values in field order.

        Args:
            dtype: optional target dtype requested by NumPy
                (e.g. ``np.asarray(obj, dtype=np.float32)``).
            copy: accepted for NumPy 2 protocol compatibility. The result is
                always built fresh from the field values, so it never shares
                memory with this object.

        Returns:
            np.ndarray: the concatenated feature values (a 1-D array when every
            feature value is 1-D). Empty when there are no fields.
        """
        parts = [self.__get_feature_value(f) for f in fields(self)]
        # np.concatenate rejects an empty sequence, so handle no fields here.
        result = np.concatenate(parts) if parts else np.empty(0)
        if dtype is not None:
            result = result.astype(dtype, copy=False)
        return result

    def __get_feature_value(self, _field: Field) -> NDArray[np.float64]:
        """
        Computes the feature value of a single field as an array.

        Args:
            _field: dataclass field of this FeaturesSet.

        Returns:
            NDArray: a shape ``(1,)`` array for scalar values, or the array
            returned by ``get_value_fn`` promoted to at least one dimension.
        """
        field_value = getattr(self, _field.name)
        # A missing 'get_value_fn' key behaves the same as an explicit None.
        get_value_fn = _field.metadata.get('get_value_fn')
        if get_value_fn is None:
            return np.array(field_value).reshape((1,))
        feature_value = get_value_fn(field_value)
        if isinstance(feature_value, np.ndarray):
            # 0-d arrays become shape (1,) so np.concatenate accepts them.
            # Arrays with ndim >= 1 are returned unchanged.
            return np.atleast_1d(feature_value)
        return np.array(feature_value).reshape((1,))