Classification of device activities (tracked with ActivityWatch) from EEG data.
# Imports
import logging
from typing import Dict
from collections import defaultdict, Counter
from datetime import date
import matplotlib.pyplot as plt
import pandas as pd
import eegclassify
from eegclassify import main, load, clean, features, preprocess, plot, transform
logger = logging.getLogger(__name__)
# Set this to True to run on testing data
simulate_test = True
if simulate_test:
    import os

    # Pretend we are running under pytest — presumably makes the project's
    # loaders pick the small testing dataset; confirm in eegclassify.load.
    os.environ['PYTEST_CURRENT_TEST'] = "true"

# Notebook plotting setup (IPython magic + matplotlib defaults)
%matplotlib inline
plt.rcParams['figure.dpi'] = 300
plt.rcParams["font.family"] = "serif" # since we're including the figures in serif-typed tex
%%javascript
// Cell magic: this runs in the browser tab, not in the Python kernel.
document.title='erb-thesis/Activity - Jupyter' // Set the document title to be able to track time spent working on the notebook with ActivityWatch
# Load data and save into special variable that won't be overwritten (since loading takes a while)
# Re-running later cells only touches `df`, never `df_loaded`.
df_loaded = load.load_labeled_eeg2()
100%|██████████| 1059/1059 [00:02<00:00, 412.04it/s]
# TODO: Split data into sessions to perform out-of-session cross-validation
# NOTE(review): this emits a pandas FutureWarning for the datetime columns;
# the warning suggests `datetime_is_numeric=True`, but that keyword is
# version-dependent — revisit when pinning the pandas version.
df_loaded.describe()
/tmp/ipykernel_7092/3530567409.py:3: FutureWarning: Treating datetime data as categorical rather than numeric in `.describe` is deprecated and will be removed in a future version of pandas. Specify `datetime_is_numeric=True` to silence this warning and adopt the future behavior now. df_loaded.describe()
start | stop | class | raw_data | |
---|---|---|---|---|
count | 4 | 4 | 4 | 4 |
unique | 4 | 4 | 2 | 4 |
top | 2020-11-01 13:16:26.806000+00:00 | 2020-11-01 13:17:05.971000+00:00 | Editing->Code | [(2020-11-01 13:16:26.809999872+00:00, -192.87... |
freq | 1 | 1 | 2 | 1 |
first | 2020-11-01 13:16:26.806000+00:00 | 2020-11-01 13:17:05.971000+00:00 | NaN | NaN |
last | 2020-11-01 13:19:52.494000+00:00 | 2020-11-01 13:20:22.553000+00:00 | NaN | NaN |
# Preprocess
df = df_loaded  # work on a separate reference so df_loaded stays intact
# Split long events into windows (min_duration presumably in seconds — confirm
# in preprocess.split_rows; later cells assume 5 s windows)
df = preprocess.split_rows(df, min_duration=5)
#df = clean.clean(df)
df
start | stop | class | raw_data | |
---|---|---|---|---|
0 | 2020-11-01 13:16:26.806000+00:00 | 2020-11-01 13:16:31.806000+00:00 | Editing->Code | [(2020-11-01 13:16:26.809999872+00:00, -192.87... |
1 | 2020-11-01 13:17:18.996000+00:00 | 2020-11-01 13:17:23.996000+00:00 | Editing->Code | [(2020-11-01 13:17:18.997999872+00:00, -1000.0... |
2 | 2020-11-01 13:18:47.374000+00:00 | 2020-11-01 13:18:52.374000+00:00 | Editing->Prose | [(2020-11-01 13:18:47.376000+00:00, -136.23, -... |
3 | 2020-11-01 13:19:52.494000+00:00 | 2020-11-01 13:19:57.494000+00:00 | Editing->Prose | [(2020-11-01 13:19:52.497999872+00:00, 526.855... |
4 | 2020-11-01 13:16:31.806000+00:00 | 2020-11-01 13:16:36.806000+00:00 | Editing->Code | [(2020-11-01 13:16:31.809999872+00:00, 732.422... |
5 | 2020-11-01 13:17:23.996000+00:00 | 2020-11-01 13:17:28.996000+00:00 | Editing->Code | [(2020-11-01 13:17:23.997999872+00:00, -1000.0... |
6 | 2020-11-01 13:18:52.374000+00:00 | 2020-11-01 13:18:57.374000+00:00 | Editing->Prose | [(2020-11-01 13:18:52.376000+00:00, -866.211, ... |
7 | 2020-11-01 13:19:57.494000+00:00 | 2020-11-01 13:20:02.494000+00:00 | Editing->Prose | [(2020-11-01 13:19:57.497999872+00:00, -1000.0... |
8 | 2020-11-01 13:16:36.806000+00:00 | 2020-11-01 13:16:41.806000+00:00 | Editing->Code | [(2020-11-01 13:16:36.809999872+00:00, 962.891... |
9 | 2020-11-01 13:17:28.996000+00:00 | 2020-11-01 13:17:33.996000+00:00 | Editing->Code | [(2020-11-01 13:17:28.997999872+00:00, -1000.0... |
10 | 2020-11-01 13:18:57.374000+00:00 | 2020-11-01 13:19:02.374000+00:00 | Editing->Prose | [(2020-11-01 13:18:57.376000+00:00, -443.848, ... |
11 | 2020-11-01 13:20:02.494000+00:00 | 2020-11-01 13:20:07.494000+00:00 | Editing->Prose | [(2020-11-01 13:20:02.497999872+00:00, -114.74... |
12 | 2020-11-01 13:16:41.806000+00:00 | 2020-11-01 13:16:46.806000+00:00 | Editing->Code | [(2020-11-01 13:16:41.809999872+00:00, 296.875... |
13 | 2020-11-01 13:17:33.996000+00:00 | 2020-11-01 13:17:38.996000+00:00 | Editing->Code | [(2020-11-01 13:17:33.997999872+00:00, -778.80... |
14 | 2020-11-01 13:19:02.374000+00:00 | 2020-11-01 13:19:07.374000+00:00 | Editing->Prose | [(2020-11-01 13:19:02.376000+00:00, -1000.0, -... |
15 | 2020-11-01 13:20:07.494000+00:00 | 2020-11-01 13:20:12.494000+00:00 | Editing->Prose | [(2020-11-01 13:20:07.497999872+00:00, -998.53... |
16 | 2020-11-01 13:16:46.806000+00:00 | 2020-11-01 13:16:51.806000+00:00 | Editing->Code | [(2020-11-01 13:16:46.809999872+00:00, -796.87... |
17 | 2020-11-01 13:17:38.996000+00:00 | 2020-11-01 13:17:43.996000+00:00 | Editing->Code | [(2020-11-01 13:17:38.997999872+00:00, -449.21... |
18 | 2020-11-01 13:17:43.996000+00:00 | 2020-11-01 13:17:49.228000+00:00 | Editing->Code | [(2020-11-01 13:17:43.997999872+00:00, 622.559... |
19 | 2020-11-01 13:19:07.374000+00:00 | 2020-11-01 13:19:12.374000+00:00 | Editing->Prose | [(2020-11-01 13:19:07.376000+00:00, -1000.0, -... |
20 | 2020-11-01 13:19:12.374000+00:00 | 2020-11-01 13:19:21.431000+00:00 | Editing->Prose | [(2020-11-01 13:19:12.376000+00:00, 817.383, -... |
21 | 2020-11-01 13:20:12.494000+00:00 | 2020-11-01 13:20:17.494000+00:00 | Editing->Prose | [(2020-11-01 13:20:12.497999872+00:00, -1000.0... |
22 | 2020-11-01 13:20:17.494000+00:00 | 2020-11-01 13:20:22.553000+00:00 | Editing->Prose | [(2020-11-01 13:20:17.497999872+00:00, 130.859... |
23 | 2020-11-01 13:16:51.806000+00:00 | 2020-11-01 13:16:56.806000+00:00 | Editing->Code | [(2020-11-01 13:16:51.809999872+00:00, -1000.0... |
24 | 2020-11-01 13:16:56.806000+00:00 | 2020-11-01 13:17:05.971000+00:00 | Editing->Code | [(2020-11-01 13:16:56.809999872+00:00, 524.902... |
# Number of events per class after splitting
Counter(df['class'])
Counter({'Editing->Code': 13, 'Editing->Prose': 12})
# NOTE: This says nothing about the actual number of samples, only the number of events
plot.classdistribution(df)
def df_to_seconds_per_day_and_class(df, sfreq: int = 256) -> Dict[date, Dict[str, float]]:
    """Sum up seconds of EEG data per calendar day and activity class.

    Args:
        df: DataFrame with ``start`` (timestamp), ``class`` (str) and
            ``raw_data`` (sequence of samples) columns.
        sfreq: sampling rate used to convert sample counts to seconds.
            Defaults to 256 Hz, which was previously hard-coded.

    Returns:
        Mapping of ``date -> class -> seconds`` of recorded data.
    """
    d: Dict[date, Dict[str, float]] = defaultdict(lambda: defaultdict(int))
    # Single pass over the rows. The original iterated the whole frame once
    # per distinct date (O(days * rows)) and its loop variable `date`
    # shadowed the `date` type imported from datetime.
    for _idx, entry in df.iterrows():
        day = entry['start'].date()
        d[day][entry['class']] += len(entry['raw_data']) / sfreq
    return d
seconds_per_day_and_class = df_to_seconds_per_day_and_class(df)
# Total seconds of data per day (summed over classes).
# Iterating .items() avoids both the redundant key lookup and shadowing the
# `date` type imported from datetime.
{d: sum(per_class.values()) for d, per_class in seconds_per_day_and_class.items()}
{datetime.date(2020, 11, 1): 133.5078125}
# Keep only dates and classes with enough recorded data, then tidy up for plotting.
min_windows = 10 if simulate_test else 100
min_seconds = min_windows * 5  # each window holds 5 s of data

combined_df = pd.DataFrame(seconds_per_day_and_class).T
combined_df = combined_df[combined_df.sum(axis=1) > min_seconds]         # per-date threshold
combined_df = combined_df.loc[:, combined_df.sum(axis=0) > min_seconds]  # per-class threshold
combined_df = combined_df.sort_index(axis=0).sort_index(axis=1)
combined_df = combined_df.filter(['Editing->Code', 'Editing->Prose', 'Twitter', 'YouTube'])
combined_df = combined_df.rename({'Editing->Code': 'Programming', 'Editing->Prose': 'Writing'}, axis=1)
combined_df
Programming | Writing | |
---|---|---|
2020-11-01 | 69.394531 | 64.113281 |
# Horizontal bars of minutes per day ( /60 converts seconds to minutes;
# [::-1] reverses the date order for the barh layout)
(combined_df[::-1]/60).plot.barh()
#plt.label("Date")
plt.xlabel("Minutes of data");
# Totals per category, in seconds
combined_df.sum().plot.bar(rot=0, stacked=True)
plt.xlabel("Category")
plt.ylabel("Seconds of data");
# Can we do PCA on the signal?
# Silence noisy log output from the transform module
logging.getLogger('eegclassify.transform').setLevel(logging.ERROR)
X, y = transform.signal_ndarray(df)
# presumably (n_windows, n_channels, n_samples) — confirm in transform.signal_ndarray
print(X.shape)
#plot.pca(X, y)
(25, 4, 1250)
# Build one DataFrame per classification task (multiclass + pairwise binary).
all_dfs = []
# all classes with decent count
all_dfs += [clean._remove_rare(df, "class", threshold_count=50)]
# Code vs Prose
all_dfs += [clean._select_classes(df, "class", ["Editing->Code", "Editing->Prose"])]
if not simulate_test:
    # These pairings are skipped on the testing data (which only contains
    # the two Editing classes — see the class counts above).
    # Code vs Twitter
    all_dfs += [clean._select_classes(df, "class", ["Editing->Code", "Twitter"])]
    # Code vs YouTube
    all_dfs += [clean._select_classes(df, "class", ["Editing->Code", "YouTube"])]
    # Prose vs Twitter
    all_dfs += [clean._select_classes(df, "class", ["Editing->Prose", "Twitter"])]
    # Prose vs YouTube (roughly same class size)
    all_dfs += [clean._select_classes(df, "class", ["Editing->Prose", "YouTube"])]
    # Twitter vs YouTube
    all_dfs += [clean._select_classes(df, "class", ["Twitter", "YouTube"])]
# GitHub PR vs issue
#all_dfs += [clean._select_classes(df, "class", ["GitHub->Issues", "GitHub->Pull request"])]
# Drop any selection that came back empty (class missing from the loaded data)
len_before = len(all_dfs)
all_dfs = [df for df in all_dfs if len(df) > 0]
len_after = len(all_dfs)
logger.warning(f"Removed {len_before - len_after} dfs due to zero length")
Removed 1 dfs due to zero length
# Train
import importlib

# Reload project modules so the notebook picks up source edits without a kernel restart
importlib.reload(eegclassify.main)
importlib.reload(eegclassify.transform)
for df_train in all_dfs:
    print(Counter(df_train['class']))
    # rows are ~5 s windows (split_rows min_duration=5), hence rows * 5 / 3600 hours
    print(f"Hours of data: {round(len(df_train['class']) * 5 / 60 / 60, 2)}")
    try:
        main._train_raw(df_train, shuffle=True)
    except Exception as e:
        # TODO: Fix testing data such that it doesn't err
        print("Error while training:", e)
Counter({'Editing->Code': 13, 'Editing->Prose': 12}) Hours of data: 0.03 Error while training: Found array with dim 3. LinearDiscriminantAnalysis expected <= 2.
def clean_empty_data(df):
    """Drop rows whose ``raw_data`` is empty and report how many were removed.

    FIXME: Where do these empty rows come from? Bad signal quality? Empty in source data?
    """
    len_before = len(df)
    # `.map(len)` is the idiomatic (and cheaper) spelling of `.apply(lambda x: len(x))`
    df = df[df['raw_data'].map(len) > 0]
    len_after = len(df)
    print(f"Removed {len_before - len_after} rows due to empty raw_data")
    return df
# Same tasks, but training on extracted features (main._train_features)
# instead of the raw signal.
for df_train in all_dfs:
    print(f"Length of df: {len(df_train)}")
    try:
        # Empty raw_data rows break feature extraction, so drop them first
        df_train = clean_empty_data(df_train)
        print(Counter(df_train["class"]))
        main._train_features(df_train)
    except Exception as e:
        # TODO: Fix testing data such that it doesn't err
        logger.exception("Error while training:")
Length of df: 25 Removed 0 rows due to empty raw_data Counter({'Editing->Code': 13, 'Editing->Prose': 12})
/home/runner/work/thesis/thesis/.venv/lib/python3.8/site-packages/outdated/utils.py:14: OutdatedPackageWarning: The package yasa is out of date. Your version is 0.5.1, the latest is 0.6.1. Set the environment variable OUTDATED_IGNORE=1 to disable these warnings. return warn(