deneir/trust/EMSOFT24/high_checker.py
2024-05-22 10:36:50 -04:00

55 lines
No EOL
1.9 KiB
Python

import numpy as np
from datetime import datetime, timedelta
from glob import glob
# PARAMETERS
# Nominal duration of one recorded trace.
trace_time_length = timedelta(hours=4)
# Sampling frequency of the label traces, in Hz.
sampling_rate = 20
# Duration of the window that produces one compliance value.
chunk_time_length = timedelta(minutes=10)
print(
    f"Considering {trace_time_length / chunk_time_length} chunks of "
    f"{chunk_time_length} over a {trace_time_length} trace."
)
# Input location: a base folder plus a glob pattern relative to it.
data_folder = "./data/"
data_selector = "pred_default/preds_*.npy"
# LOAD DATA: list matching trace files in a stable (sorted) order.
data_filenames = sorted(glob(f"{data_folder}{data_selector}"))
print(f"Listed {len(data_filenames)} traces.")
def _longest_run(mask):
    """Return the length of the longest run of consecutive True values in a 1-D mask."""
    # Pad with False on both ends so every run has a rising and a falling edge.
    padded = np.concatenate(([False], np.asarray(mask, dtype=bool), [False]))
    edges = np.flatnonzero(padded[1:] != padded[:-1])
    if edges.size == 0:
        return 0
    # Edges alternate start, end, start, end, ...; end - start is a run length.
    return int((edges[1::2] - edges[0::2]).max())


# define the policy checking function
def checker_long_high(labels, max_high=timedelta(minutes=3), rate_hz=None):
    """Check the "no long High load" policy on a trace of state labels.

    Policy: no continuous High load (label == 2) lasting ``max_high`` or
    longer, i.e. ``int(max_high.total_seconds() * rate_hz)`` consecutive
    samples (at least one sample).

    A label of -1 is treated as unknown: it may or may not be High. The
    result follows three-valued logic:

    * ``-1`` (Not OK): a run made only of certain High samples reaches the
      threshold. This definite violation wins even if an uncertain run
      appears earlier in the trace.
    * ``0`` (Unsure): no certain violation exists, but treating unknown
      samples as High would produce a long enough run.
    * ``1`` (OK): no run can reach the threshold.

    Args:
        labels: 1-D sequence / NumPy array of per-sample state labels.
        max_high: Maximum tolerated High duration (default 3 minutes).
        rate_hz: Sampling rate of ``labels`` in Hz. Defaults to the
            module-level ``sampling_rate``.

    Returns:
        int: 1, 0 or -1 as described above.
    """
    if rate_hz is None:
        rate_hz = sampling_rate
    # Clamp to one sample so a zero/negative limit still means "any High
    # sample is too much" instead of flagging empty traces.
    req_len = max(1, int(max_high.total_seconds() * rate_hz))
    labels = np.asarray(labels)
    certainly_high = labels == 2
    possibly_high = certainly_high | (labels == -1)
    if _longest_run(certainly_high) >= req_len:
        return -1
    if _longest_run(possibly_high) >= req_len:
        return 0
    return 1
def load_data(filename):
    """Load one trace of predicted labels from a ``.npy`` file.

    Args:
        filename: Path (str or path-like) of the file to load.

    Returns:
        The array stored in the file, as returned by ``np.load``.

    Raises:
        TypeError: if ``filename`` does not end with the ``.npy`` extension.
    """
    name = str(filename)
    # endswith() rather than split(".") so a bare file literally named "npy"
    # is not mistaken for a NumPy file.
    if not name.endswith(".npy"):
        raise TypeError(f"Unsupported trace file {name!r}: expected a .npy file")
    return np.load(name)
# Tally of compliance verdicts over all chunks of all traces, keyed by the
# checker's ternary result as a string ("1" OK, "0" unsure, "-1" not OK).
counts = {"1": 0, "0": 0, "-1": 0}
# Number of samples in one checking chunk.
chunk_length = int(chunk_time_length.total_seconds() * sampling_rate)
for filename in data_filenames:
    preds = load_data(filename)
    # Start index of every *complete* chunk. A trailing partial chunk is
    # skipped, but a trace whose length is an exact multiple of chunk_length
    # keeps its last (full) chunk; the previous np.arange(...)[:-1] dropped it.
    chunk_starts = range(0, preds.shape[0] - chunk_length + 1, chunk_length)
    compliance = [checker_long_high(preds[i:i + chunk_length]) for i in chunk_starts]
    for key in counts:
        counts[key] += compliance.count(int(key))
    print(f"{filename.split('/')[-1]}: {compliance}")