DataJoint Element for Motion Sequencing with Keypoint-MoSeq
Open-source Data Pipeline for Motion Sequencing in Neurophysiology
Welcome to the tutorial for the DataJoint Element for motion sequencing analysis. This tutorial aims to provide a comprehensive understanding of the open-source data pipeline built with `element-moseq`.
The package integrates PCA fitting and keypoint-MoSeq model fitting (initialization, AR-HMM fitting, and full keypoint-SLDS fitting) into a data pipeline, and streamlines model and video management using DataJoint.
By the end of this tutorial, you will have a clear grasp of how to set up and integrate Element MoSeq into your own research projects and lab.
Prerequisites
Please see the DataJoint tutorials GitHub repository before proceeding. A basic understanding of the following DataJoint concepts will be beneficial (a short sketch follows the list):
- The `Imported` and `Computed` table types in `datajoint-python`.
- The functionality of the `.populate()` method.
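If these concepts are new, the following minimal sketch (with hypothetical schema and table names) illustrates the pattern this Element relies on: a `Computed` table defines a `make` method, and calling `.populate()` runs `make` once for every upstream key that has not yet been processed.

```python
import datajoint as dj

schema = dj.schema("tutorial_sketch")  # hypothetical schema name


@schema
class Recording(dj.Manual):  # hypothetical upstream table
    definition = """
    recording_id : int
    ---
    n_frames     : int
    """


@schema
class FrameCount(dj.Computed):  # hypothetical downstream table
    definition = """
    -> Recording
    ---
    total_frames : int
    """

    def make(self, key):
        # `make` receives one primary key at a time and inserts its result
        n_frames = (Recording & key).fetch1("n_frames")
        self.insert1({**key, "total_frames": n_frames})


# .populate() calls make() once per Recording entry not yet in FrameCount
# FrameCount.populate()
```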
Tutorial Overview
- Setup
- Activate the DataJoint pipeline
- Insert example data into subject and session tables
- Insert the pose estimation keypoint data and body parts into the DataJoint pipeline
- Fit a PCA model to aligned and centered keypoint coordinates and select the latent dimension
- Train the AR-HMM and Keypoint-SLDS Models
- Run the inference task and visualize the results
Setup
This tutorial loads the keypoint data extracted with DeepLabCut from a single freely moving mouse in an open-field environment. The same open-source data is used as an example in the Keypoint-MoSeq Colab tutorial.
The goal is to link this keypoint tracking to pose dynamics by identifying behavioral modules ("syllables") without human supervision. The modeling results are stored as an `.h5` file and a subdirectory of `.csv` files that contain the following information (an inspection sketch follows the list):
- Behavioral modules ("syllables"): the syllable label assigned to each frame (i.e., the state indexes assigned by the model)
- Centroid and heading in each frame, as estimated by the model, which capture the animal's overall position in allocentric coordinates
- Latent state: a low-dimensional representation of the animal's pose in each frame. These are similar to PCA scores, modified to reflect the pose dynamics and noise estimates inferred by the model.
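As a rough sketch of how these outputs could be inspected outside the pipeline, the snippet below walks a results `.h5` file with `h5py` and prints every dataset name and shape. The path and exact group layout are assumptions; adjust them to your own Keypoint-MoSeq output.

```python
from pathlib import Path

import h5py

# Hypothetical location of a keypoint-MoSeq results file
results_file = Path("path/to/kpms_project/results.h5")


def describe(name, obj):
    # Print each dataset (e.g. syllable, centroid, heading, latent_state) with its shape
    if isinstance(obj, h5py.Dataset):
        print(f"{name}: shape={obj.shape}, dtype={obj.dtype}")


if results_file.exists():
    with h5py.File(results_file, "r") as f:
        f.visititems(describe)
```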
The results of this Element can be combined with other modalities to create a complete, customizable data pipeline for your specific lab or study. For instance, you can combine `element-moseq` with `element-deeplabcut` and `element-calcium-imaging` to characterize neural activity along with the natural sub-second rhythmicity of mouse movement.
Steps to Run Element MoSeq
The input data for this data pipeline is as follows:
- A DeepLabCut (DLC) project folder containing the configuration (`.yaml`) file, the video set (`.mp4` files), and the keypoint tracking (`.h5`) files.
- A selection of the anterior, posterior, and use bodyparts for model fitting.
This tutorial includes the example keypoint data in `example_data/inbox/dlc_project`.
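Before ingesting anything, it can help to confirm that the project folder contains the expected file types. This is only a convenience check with `pathlib`; adjust the path if your root data directory differs.

```python
from pathlib import Path

dlc_project = Path("example_data/inbox/dlc_project")

# Count the expected input files: DLC config, videos, and keypoint files
print("config files:", len(list(dlc_project.glob("*.yaml"))))
print("video files :", len(list(dlc_project.rglob("*.[mM]p4"))))
print("keypoint h5 :", len(list(dlc_project.rglob("*.h5"))))
```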
Let's start this tutorial by importing the packages necessary to run the data pipeline.
import os
if os.path.basename(os.getcwd()) == "notebooks":
    os.chdir("..")
import datajoint as dj
from pathlib import Path
import numpy as np
from datetime import datetime
from element_moseq.moseq_infer import get_kpms_processed_data_dir
If the tutorial is run in Codespaces, a private, local database server is created and made available for you. This is where we will insert and store our processed results.
Let's connect to the database server.
dj.conn()
[2024-08-17 00:40:37,105][INFO]: Connecting root@fakeservices.datajoint.io:3306
[2024-08-17 00:40:37,112][INFO]: Connected root@fakeservices.datajoint.io:3306
DataJoint connection (connected) root@fakeservices.datajoint.io:3306
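Outside of Codespaces, you would first point DataJoint at your own database server before calling `dj.conn()`. A minimal sketch with placeholder credentials (replace the host, user, and password with your own); the settings can optionally be saved to a local configuration file.

```python
import datajoint as dj

# Placeholder credentials; replace with your own database settings
dj.config["database.host"] = "db.example.org"
dj.config["database.user"] = "username"
dj.config["database.password"] = "password"

# Optionally persist the settings to dj_local_conf.json in the working directory
# dj.config.save_local()

dj.conn()
```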
Activate the DataJoint pipeline
This tutorial presumes that `element-moseq` has been pre-configured and instantiated, with its tables connected downstream of pre-existing `subject` and `session` tables. Please refer to `tutorial_pipeline.py` for the source code.
Now, we will import the schemas required to construct this data pipeline, paying particular attention to its primary components: `moseq_train` and `moseq_infer`.
from tutorial_pipeline import lab, subject, session, moseq_train, moseq_infer
[2024-08-17 00:40:41,878][WARNING]: lab.Project and related tables will be removed in a future version of Element Lab. Please use the project schema.
We can represent the tables in the `moseq_train` and `moseq_infer` schemas, as well as some of their upstream dependencies in the `session` and `subject` schemas, as a diagram.
(
dj.Diagram(subject.Subject)
+ dj.Diagram(session.Session)
+ dj.Diagram(moseq_train)
+ dj.Diagram(moseq_infer)
)
As evident from the diagram, this data pipeline encompasses several tables associated with the different keypoint-MoSeq components, such as PCA fitting, AR-HMM pre-fitting, and full model fitting. A few tables, such as `subject.Subject` or `session.Session`, while important for a complete pipeline, fall outside the scope of this `element-moseq` tutorial and will therefore not be explored extensively here. The primary focus of this tutorial will be the `moseq_train` and `moseq_infer` schemas.
dj.Diagram(moseq_train) + dj.Diagram(moseq_infer)
Insert example data into subject and session tables
Let's delve into the `subject.Subject` and `session.Session` tables and insert some example data.
subject.Subject()
subject | subject_nickname | sex | subject_birth_date | subject_description |
---|---|---|---|---|
Total: 0
Add a new entry for a subject in the `Subject` table:
subject.Subject.insert1(
dict(
subject="subject1",
sex="F",
subject_birth_date="2024-01-01",
subject_description="test subject",
),
skip_duplicates=True,
)
Create session keys and insert them into the `Session` table:
# Define a list of session keys
session_keys = [
    dict(subject="subject1", session_datetime="2024-03-15 14:04:22"),
    dict(subject="subject1", session_datetime="2024-03-16 14:43:10"),
]
# Insert the session keys into the Session table
session.Session.insert(session_keys, skip_duplicates=True)
Confirm the inserted data:
session.Session()
subject | session_datetime | session_id |
---|---|---|
subject1 | 2024-03-15 14:04:22 | None |
subject1 | 2024-03-16 14:43:10 | None |
Total: 2
Let's define a `session_key` to use throughout the notebook:
session_key = dict(subject="subject1", session_datetime="2024-03-15 14:04:22")
session_key
{'subject': 'subject1', 'session_datetime': '2024-03-15 14:04:22'}
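Any table in the pipeline can be restricted by this key with the `&` operator, which is how the rest of the notebook refers back to this particular session. For example:

```python
# Restrict the Session table to the single entry matching session_key
session.Session & session_key

# fetch1() returns that entry as a dictionary
(session.Session & session_key).fetch1()
```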
Insert the pose estimation keypoint data and body parts into the DataJoint pipeline
The `PoseEstimationMethod` table contains the pose estimation methods and file formats supported by the keypoint loader of the `keypoint-moseq` package. In this tutorial, the keypoint input data are `.h5` files obtained using `DeepLabCut`.
moseq_infer.PoseEstimationMethod()
pose_estimation_method Supported pose estimation method (deeplabcut, sleap, anipose, sleap-anipose, nwb, facemap) | pose_estimation_desc Optional. Pose estimation method description with the supported formats. |
---|---|
anipose | `.csv` files generated by anipose analysis |
deeplabcut | `.csv` and `.h5/.hdf5` files generated by DeepLabcut analysis |
facemap | `.h5` files generated by Facemap analysis |
nwb | `.nwb` files with Neurodata Without Borders (NWB) format |
sleap | `.slp` and `.h5/.hdf5` files generated by SLEAP analysis |
sleap-anipose | `.h5/.hdf5` files generated by sleap-anipose analysis |
Total: 6
Insert keypoint input metadata into the `KeypointSet` table:
moseq_train.KeypointSet.insert1(
{
"kpset_id": 1,
"pose_estimation_method": "deeplabcut",
"kpset_dir": "dlc_project",
"kpset_desc": "Example keypoint set",
}
)
moseq_train.KeypointSet()
kpset_id Unique ID for each keypoint set | pose_estimation_method Supported pose estimation method (deeplabcut, sleap, anipose, sleap-anipose, nwb, facemap) | kpset_dir Path where the keypoint files are located together with the pose estimation `config` file, relative to root data directory | kpset_desc Optional. User-entered description |
---|---|---|---|
1 | deeplabcut | dlc_project | Example keypoint set |
Total: 1
Add the video files to `KeypointSet.VideoFile` that will be used to fit the model:
videos_path = [
"dlc_project/videos/21_12_10_def6a_3.top.ir.mp4",
"dlc_project/videos/22_04_26_cage4_1_1.top.ir.mp4",
"dlc_project/videos/21_12_10_def6a_1_1.top.ir.mp4",
"dlc_project/videos/22_27_04_cage4_mouse2_0.top.ir.mp4",
"dlc_project/videos/22_04_26_cage4_0.top.ir.mp4",
"dlc_project/videos/21_11_8_one_mouse.top.ir.Mp4",
"dlc_project/videos/21_12_2_def6b_2.top.ir.mp4",
"dlc_project/videos/21_12_10_def6b_3.top.ir.Mp4",
"dlc_project/videos/22_04_26_cage4_0_2.top.ir.mp4",
"dlc_project/videos/21_12_2_def6a_1.top.ir.mp4",
]
# Insert the video files in the `VideoFile` table
moseq_train.KeypointSet.VideoFile.insert(
(
{"kpset_id": 1, "video_id": v_idx, "video_path": f}
for v_idx, f in enumerate(videos_path)
),
skip_duplicates=True,
)
moseq_train.KeypointSet.VideoFile()
kpset_id Unique ID for each keypoint set | video_id Unique ID for each video corresponding to each keypoint data file, relative to root data directory | video_path Filepath of each video from which the keypoints are derived, relative to root data directory |
---|---|---|
1 | 0 | dlc_project/videos/21_12_10_def6a_3.top.ir.mp4 |
1 | 1 | dlc_project/videos/22_04_26_cage4_1_1.top.ir.mp4 |
1 | 2 | dlc_project/videos/21_12_10_def6a_1_1.top.ir.mp4 |
1 | 3 | dlc_project/videos/22_27_04_cage4_mouse2_0.top.ir.mp4 |
1 | 4 | dlc_project/videos/22_04_26_cage4_0.top.ir.mp4 |
1 | 5 | dlc_project/videos/21_11_8_one_mouse.top.ir.Mp4 |
1 | 6 | dlc_project/videos/21_12_2_def6b_2.top.ir.mp4 |
1 | 7 | dlc_project/videos/21_12_10_def6b_3.top.ir.Mp4 |
1 | 8 | dlc_project/videos/22_04_26_cage4_0_2.top.ir.mp4 |
1 | 9 | dlc_project/videos/21_12_2_def6a_1.top.ir.mp4 |
Total: 10
Now, let's insert the body parts to use in the analysis:
pca_task_key = {"kpset_id": 1, "bodyparts_id": 1}
moseq_train.Bodyparts.insert1(
{
**pca_task_key,
"anterior_bodyparts": ["nose"],
"posterior_bodyparts": ["spine4"],
"use_bodyparts": [
"spine4",
"spine3",
"spine2",
"spine1",
"head",
"nose",
"right ear",
"left ear",
],
"bodyparts_desc": "Example of KPMS bodyparts extracted with DLC 2.3.9",
}
)
moseq_train.Bodyparts()
kpset_id Unique ID for each keypoint set | bodyparts_id Unique ID for a set of bodyparts for a particular keypoint set | anterior_bodyparts List of strings of anterior bodyparts | posterior_bodyparts List of strings of posterior bodyparts | use_bodyparts List of strings of bodyparts to be used | bodyparts_desc Optional. User-entered description |
---|---|---|---|---|---|
1 | 1 | =BLOB= | =BLOB= | =BLOB= | Example of KPMS bodyparts extracted with DLC 2.3.9 |
Total: 1
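The `=BLOB=` entries above store the Python lists we just inserted; they can be fetched back directly, for example to double-check the bodyparts that will be used:

```python
# Fetch the stored lists back from the Bodyparts table
use_bodyparts = (moseq_train.Bodyparts & pca_task_key).fetch1("use_bodyparts")
anterior, posterior = (moseq_train.Bodyparts & pca_task_key).fetch1(
    "anterior_bodyparts", "posterior_bodyparts"
)
print(use_bodyparts, anterior, posterior)
```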
Fit a PCA model to aligned and centered keypoint coordinates and select the latent dimension
To conduct model fitting for keypoint-MoSeq, both a PCA model and the latent dimension of the pose trajectory are necessary.
dj.Diagram(moseq_train)
The `PCATask` table serves the purpose of specifying the PCA task.
moseq_train.PCATask()
kpset_id Unique ID for each keypoint set | bodyparts_id Unique ID for a set of bodyparts for a particular keypoint set | kpms_project_output_dir Keypoint-MoSeq project output directory, relative to root data directory | task_mode Trigger or load the task |
---|---|---|---|
Total: 0
Defining and inserting a PCA task requires:
- Selecting a keypoint set
- Selecting the body parts to use
- Specifying the output directory for the KPMS project
moseq_train.PCATask.insert1(
{
**pca_task_key,
"kpms_project_output_dir": "kpms_project_tutorial",
"task_mode": "load",
}
)
moseq_train.PCATask()
kpset_id Unique ID for each keypoint set | bodyparts_id Unique ID for a set of bodyparts for a particular keypoint set | kpms_project_output_dir Keypoint-MoSeq project output directory, relative to root data directory | task_mode Trigger or load the task |
---|---|---|---|
1 | 1 | kpms_project_tutorial | load |
Total: 1
Before running the PCA fitting, the keypoint detections and body parts need to be formatted. The resulting coordinates and confidence scores will be used to format the data for modeling.
moseq_train.PCAPrep()
kpset_id Unique ID for each keypoint set | bodyparts_id Unique ID for a set of bodyparts for a particular keypoint set | coordinates Dictionary mapping filenames to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2[or 3]) | confidences Dictionary mapping filenames to `likelihood` scores as ndarrays of shape (n_frames, n_bodyparts) | formatted_bodyparts List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. | average_frame_rate Average frame rate of the videos for model training | frame_rates List of the frame rates of the videos for model training |
---|---|---|---|---|---|---|
Total: 0
Populating the `PCAPrep` table will:
- Create the output directory, if it does not exist, with the KPMS default `config.yml` file that contains the default values from the pose estimation
- Generate a copy as `dj_config.yml` and update it with both the video directory and the bodyparts
- Create and store the keypoint coordinates and confidence scores used to format the data for the PCA fitting
- Calculate the average frame rate of the video set chosen to train the model. This will be used to calculate the kappa value in the next step.
moseq_train.PCAPrep.populate(pca_task_key)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
/home/vscode/.local/lib/python3.9/site-packages/keypoint_moseq/analysis.py:20: UserWarning: Using Panel interactively in VSCode notebooks requires the jupyter_bokeh package to be installed. You can install it with: pip install jupyter_bokeh or: conda install jupyter_bokeh and try again.
Loading keypoints: 100%|████████████████| 10/10 [00:08<00:00, 1.12it/s]
moseq_train.PCAPrep()
kpset_id Unique ID for each keypoint set | bodyparts_id Unique ID for a set of bodyparts for a particular keypoint set | coordinates Dictionary mapping filenames to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2[or 3]) | confidences Dictionary mapping filenames to `likelihood` scores as ndarrays of shape (n_frames, n_bodyparts) | formatted_bodyparts List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. | average_frame_rate Average frame rate of the videos for model training | frame_rates List of the frame rates of the videos for model training |
---|---|---|---|---|---|---|
1 | 1 | =BLOB= | =BLOB= | =BLOB= | 30.0 | =BLOB= |
Total: 1
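The formatted coordinates and confidences are stored as blobs keyed by filename. A quick way to sanity-check them is to fetch the dictionaries back and inspect the array shapes:

```python
# Inspect the formatted keypoint data stored by PCAPrep
coordinates, confidences = (moseq_train.PCAPrep & pca_task_key).fetch1(
    "coordinates", "confidences"
)
for name, coords in coordinates.items():
    print(name, coords.shape, confidences[name].shape)
```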
The `PCAFit` computation will format the aligned and centered keypoint coordinates, fit a PCA model, and save it as a `pca.p` file in the output directory.
moseq_train.PCAFit.populate(pca_task_key)
moseq_train.PCAFit()
kpset_id Unique ID for each keypoint set | bodyparts_id Unique ID for a set of bodyparts for a particular keypoint set | pca_fit_time datetime of the PCA fitting analysis |
---|---|---|
1 | 1 | None |
Total: 1
However, we still need to determine the specific dimension of the pose trajectory to utilize for fitting the keypoint-MoSeq model. A helpful guideline is to consider the number of dimensions required to explain 90% of the variance, or a maximum of 10 dimensions, whichever is lower.
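For reference, the sketch below shows how that guideline translates into a computation on a fitted scikit-learn PCA object (the `pca.p` file saved above appears to be a pickled scikit-learn `PCA`, judging by the version warnings in this notebook): count the components needed to reach 90% cumulative explained variance and cap the result at 10. The `LatentDimension` table below computes its own value; this helper is only illustrative.

```python
import numpy as np


def suggested_latent_dim(pca, variance_threshold=0.90, max_dim=10):
    """Number of PCs explaining `variance_threshold` of the variance, capped at `max_dim`."""
    cumulative = np.cumsum(pca.explained_variance_ratio_)
    # First component count at which the cumulative variance reaches the threshold
    n_components = int(np.searchsorted(cumulative, variance_threshold) + 1)
    return min(n_components, max_dim, len(cumulative))
```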
The computation of `LatentDimension` will automatically identify the number of components that explain 90% of the variance, aiding the user in making the final decision regarding an appropriate latent dimension for model fitting.
moseq_train.LatentDimension.populate(pca_task_key)
/usr/local/lib/python3.9/site-packages/sklearn/base.py:376: InconsistentVersionWarning: Trying to unpickle estimator PCA from version 1.3.2 when using version 1.5.1. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to: https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
moseq_train.LatentDimension()
kpset_id Unique ID for each keypoint set | bodyparts_id Unique ID for a set of bodyparts for a particular keypoint set | variance_percentage Variance threshold. Fixed value to 90 percent. | latent_dimension Number of principal components required to explain the specified variance. | latent_dim_desc Automated description of the computation result. |
---|---|---|---|---|
1 | 1 | 90.0 | 4 | >=90.0% of variance explained by 4 components. |
Total: 1
To aid the user in selecting the latent dimensions for model fitting, two plots are created below: a cumulative scree plot and a visualization of each Principal Component (PC). In this visualization, translucent nodes/edges represent the mean pose, while opaque nodes/edges represent a perturbation in the direction of the PC. The plots are stored in the output directory.
# Generate and store plots for the user to choose the latent dimensions in the next step
from keypoint_moseq import load_pca, plot_scree, plot_pcs
from element_moseq.readers.kpms_reader import load_kpms_dj_config
from element_moseq.moseq_infer import get_kpms_processed_data_dir
kpms_project_output_dir = (moseq_train.PCATask & pca_task_key).fetch1(
"kpms_project_output_dir"
)
kpms_project_output_dir = get_kpms_processed_data_dir() / kpms_project_output_dir
kpms_dj_config = load_kpms_dj_config(
kpms_project_output_dir.as_posix(), check_if_valid=False, build_indexes=False
)
pca = load_pca(kpms_project_output_dir.as_posix())
# plot_scree(pca, project_dir=kpms_project_output_dir.as_posix())
# plot_pcs(pca, project_dir=kpms_project_output_dir.as_posix(), **kpms_dj_config)
plot_scree(pca, savefig=False)
plot_pcs(pca, savefig=False, **kpms_dj_config)
/usr/local/lib/python3.9/site-packages/sklearn/base.py:376: InconsistentVersionWarning: Trying to unpickle estimator PCA from version 1.3.2 when using version 1.5.1. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to: https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
The chosen dimension for the next steps in the analysis will be `latent_dimension = 4`.
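Rather than hard-coding this value, it can also be fetched from the table for use in the fitting tasks below:

```python
# Fetch the computed latent dimension for use in the fitting tasks below
latent_dim = int(
    (moseq_train.LatentDimension & pca_task_key).fetch1("latent_dimension")
)
print(f"latent_dim = {latent_dim}")
```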
Train the AR-HMM and Keypoint-SLDS Models
The pre-fitting and full-fitting processes for the KPMS Model involve the following steps:
- Initialization: Auto-regressive (AR) parameters and syllable sequences are randomly initialized using pose trajectories from PCA
- Fitting an AR-HMM: AR parameters, transition probabilities and syllable sequences are iteratively updated through Gibbs sampling
- Fitting the full model: All parameters, including the AR-HMM parameters as well as the centroid, heading, noise estimates, and continuous latent states (i.e., pose trajectories), are iteratively updated through Gibbs sampling. This step is particularly useful for noisy data.
dj.Diagram(moseq_train)
For the pre-fitting step (fitting an AR-HMM), a pre-fitting task needs to be defined and inserted:
moseq_train.PreFitTask()
kpset_id Unique ID for each keypoint set | bodyparts_id Unique ID for a set of bodyparts for a particular keypoint set | pre_latent_dim Latent dimension to use for the model pre-fitting | pre_kappa Kappa value to use for the model pre-fitting | pre_num_iterations Number of Gibbs sampling iterations to run in the model pre-fitting | model_name Name of the model to be loaded if `task_mode='load'` | task_mode 'load': load computed analysis results, 'trigger': trigger computation | pre_fit_desc User-defined description of the pre-fitting task |
---|---|---|---|---|---|---|---|
Total: 0
This task requires the following inputs:
- The keypoint set, body parts, and latent dimension (extracted in the section above).
- A kappa value for the model pre-fitting.
- The number of iterations for the model pre-fitting.
Kappa hyperparameter: An important decision for the user is to adjust the kappa hyperparameter to achieve the desired distribution of syllable durations. Higher values of kappa result in longer syllables.
As a reference, let's choose a kappa value that yields a median syllable duration of 12 frames (400 ms), a duration recommended for rodents.
During the model pre-fitting, it's advisable to explore different values of kappa (`kappa_range`) until the syllable durations stabilize.
fps = (moseq_train.PCAPrep & pca_task_key).fetch1("average_frame_rate")
kappa_min = (12 / fps) * 1000 # ms
kappa_max = 1e4 # ms
kappa_range = np.logspace(np.log10(kappa_min), np.log10(kappa_max), num=3)
kappa_range = np.round(kappa_range).astype(int)
print(["kappa = {:.2f} ms".format(x) for x in kappa_range])
['kappa = 400.00 ms', 'kappa = 2000.00 ms', 'kappa = 10000.00 ms']
Number of Iterations: Typically, stabilizing the syllable duration requires 10-50 iterations during the model pre-fitting stage, while stabilizing the syllable sequence after setting kappa may take 200-500 iterations during the model full-fitting stage.
We have already prepared one model, defined by a `prefit_key` with `pre_latent_dim=4` and `pre_kappa=1e6`, that was generated with `task_mode='trigger'`.
For tutorial purposes, we will use `task_mode='load'`, which will load the pre-fitted model located in `outbox/kpms_project_tutorial`, as follows:
prefit_key = {
**pca_task_key,
"pre_latent_dim": 4,
"pre_kappa": 1000000.0,
"pre_num_iterations": 5,
"pre_fit_desc": "Tutorial PreFit task",
"task_mode": "load",
"model_name": "2024_03_28-18_14_26",
}
prefit_key
{'kpset_id': 1, 'bodyparts_id': 1, 'pre_latent_dim': 4, 'pre_kappa': 1000000.0, 'pre_num_iterations': 5, 'pre_fit_desc': 'Tutorial PreFit task', 'task_mode': 'load', 'model_name': '2024_03_28-18_14_26'}
Thus, in practice we would insert several entries (`prefit_keys`) into `PreFitTask` with various kappa values until the target syllable time scale is achieved (a hypothetical kappa-scan sketch follows).
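Such a kappa scan could be generated programmatically from the `kappa_range` computed earlier. The sketch below is hypothetical: it uses `task_mode='trigger'` so that populating `PreFit` would fit each model, and it fills `model_name` with an empty placeholder on the assumption that the model directory is named during the triggered computation (check the Element MoSeq documentation for the exact convention). In this tutorial we only insert the single load-mode entry defined above.

```python
# Hypothetical kappa scan: one PreFitTask entry per kappa value in kappa_range
prefit_keys = [
    {
        **pca_task_key,
        "pre_latent_dim": 4,
        "pre_kappa": int(kappa),
        "pre_num_iterations": 50,
        "model_name": "",  # assumption: assigned during the triggered computation
        "task_mode": "trigger",
        "pre_fit_desc": f"Kappa scan, kappa={int(kappa)}",
    }
    for kappa in kappa_range
]
# moseq_train.PreFitTask.insert(prefit_keys, skip_duplicates=True)
# moseq_train.PreFit.populate(prefit_keys)  # fits one AR-HMM per kappa value
```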
moseq_train.PreFitTask.heading
#
kpset_id             : int                  # Unique ID for each keypoint set
bodyparts_id         : int                  # Unique ID for a set of bodyparts for a particular keypoint set
pre_latent_dim       : int                  # Latent dimension to use for the model pre-fitting
pre_kappa            : int                  # Kappa value to use for the model pre-fitting
pre_num_iterations   : int                  # Number of Gibbs sampling iterations to run in the model pre-fitting
---
model_name           : varchar(100)         # Name of the model to be loaded if `task_mode='load'`
task_mode="load"     : enum('trigger','load') # 'load': load computed analysis results, 'trigger': trigger computation
pre_fit_desc=""      : varchar(1000)        # User-defined description of the pre-fitting task
moseq_train.PreFitTask.insert1(prefit_key, skip_duplicates=True)
Show the contents of the `PreFitTask` table.
moseq_train.PreFitTask()
kpset_id Unique ID for each keypoint set | bodyparts_id Unique ID for a set of bodyparts for a particular keypoint set | pre_latent_dim Latent dimension to use for the model pre-fitting | pre_kappa Kappa value to use for the model pre-fitting | pre_num_iterations Number of Gibbs sampling iterations to run in the model pre-fitting | model_name Name of the model to be loaded if `task_mode='load'` | task_mode 'load': load computed analysis results, 'trigger': trigger computation | pre_fit_desc User-defined description of the pre-fitting task |
---|---|---|---|---|---|---|---|
1 | 1 | 4 | 1000000 | 5 | 2024_03_28-18_14_26 | load | Tutorial PreFit task |
Total: 1
When populating the `PreFit` table, an AR-HMM model is fitted automatically for each kappa value defined in `PreFitTask`. This step will take a few minutes.
moseq_train.PreFit.populate(prefit_key)
moseq_train.PreFit()
kpset_id Unique ID for each keypoint set | bodyparts_id Unique ID for a set of bodyparts for a particular keypoint set | pre_latent_dim Latent dimension to use for the model pre-fitting | pre_kappa Kappa value to use for the model pre-fitting | pre_num_iterations Number of Gibbs sampling iterations to run in the model pre-fitting | model_name Name of the model as "kpms_project_output_dir/model_name" | pre_fit_duration Time duration (seconds) of the model fitting computation |
---|---|---|---|---|---|---|
1 | 1 | 4 | 1000000 | 5 | kpms_project_tutorial/2024_03_28-18_14_26 | nan |
Total: 1
Now we can define a `FullFitTask` using the selected `latent_dimension = 4` and the `kappa = 10000` chosen from the previous exploration.
Again, for tutorial purposes, we will `load` a model already generated to ensure a smooth run of this notebook.
moseq_train.FullFitTask.heading
#
kpset_id             : int                  # Unique ID for each keypoint set
bodyparts_id         : int                  # Unique ID for a set of bodyparts for a particular keypoint set
full_latent_dim      : int                  # Latent dimension to use for the model full fitting
full_kappa           : int                  # Kappa value to use for the model full fitting
full_num_iterations  : int                  # Number of Gibbs sampling iterations to run in the model full fitting
---
model_name           : varchar(100)         # Name of the model to be loaded if `task_mode='load'`
task_mode="load"     : enum('load','trigger') # Trigger or load the task
full_fit_desc=""     : varchar(1000)        # User-defined description of the model full fitting task
# modify kappa to maintain the desired syllable time-scale
full_fit_key_1 = {
**pca_task_key,
"full_latent_dim": 4,
"full_kappa": 10000.0,
"full_num_iterations": 25,
"full_fit_desc": "Fitting task with kappa = 10000 ms",
"task_mode": "load",
"model_name": "2024_03_28-18_54_08",
}
moseq_train.FullFitTask.insert1(full_fit_key_1, skip_duplicates=True)
Let's add a second `FullFitTask` entry:
full_fit_key_2 = {
**pca_task_key,
"full_latent_dim": 4,
"full_kappa": 5000.0,
"full_num_iterations": 25,
"full_fit_desc": "Fitting task with kappa = 5000 ms",
"task_mode": "load",
"model_name": "2024_03_28-18_15_54",
}
moseq_train.FullFitTask.insert1(full_fit_key_2, skip_duplicates=True)
moseq_train.FullFitTask()
kpset_id Unique ID for each keypoint set | bodyparts_id Unique ID for a set of bodyparts for a particular keypoint set | full_latent_dim Latent dimension to use for the model full fitting | full_kappa Kappa value to use for the model full fitting | full_num_iterations Number of Gibbs sampling iterations to run in the model full fitting | model_name Name of the model to be loaded if `task_mode='load'` | task_mode Trigger or load the task | full_fit_desc User-defined description of the model full fitting task |
---|---|---|---|---|---|---|---|
1 | 1 | 4 | 5000 | 25 | 2024_03_28-18_15_54 | load | Fitting task with kappa = 5000 ms |
1 | 1 | 4 | 10000 | 25 | 2024_03_28-18_54_08 | load | Fitting task with kappa = 10000 ms |
Total: 2
moseq_train.FullFit.populate([full_fit_key_1, full_fit_key_2])
moseq_train.FullFit()
kpset_id Unique ID for each keypoint set | bodyparts_id Unique ID for a set of bodyparts for a particular keypoint set | full_latent_dim Latent dimension to use for the model full fitting | full_kappa Kappa value to use for the model full fitting | full_num_iterations Number of Gibbs sampling iterations to run in the model full fitting | model_name Name of the model as "kpms_project_output_dir/model_name" | full_fit_duration Time duration (seconds) of the full fitting computation |
---|---|---|---|---|---|---|
1 | 1 | 4 | 5000 | 25 | kpms_project_tutorial/2024_03_28-18_15_54 | nan |
1 | 1 | 4 | 10000 | 25 | kpms_project_tutorial/2024_03_28-18_54_08 | nan |
Total: 2
Run the inference task and visualize the results
The models, along with their relevant information, will be registered in the DataJoint pipeline as follows:
model_name, latent_dim, kappa = (moseq_train.FullFit & "full_kappa = 10000.").fetch1(
"model_name", "full_latent_dim", "full_kappa"
)
moseq_infer.Model.insert1(
{
"model_id": 1,
"model_name": "model 1",
"model_dir": model_name,
"latent_dim": latent_dim,
"kappa": kappa,
},
skip_duplicates=True,
)
model_name, latent_dim, kappa = (moseq_train.FullFit & "full_kappa = 5000.").fetch1(
"model_name", "full_latent_dim", "full_kappa"
)
moseq_infer.Model.insert1(
{
"model_id": 2,
"model_name": "model 2",
"model_dir": model_name,
"latent_dim": latent_dim,
"kappa": kappa,
},
skip_duplicates=True,
)
We can check the `Model` table to confirm that the two models have been registered:
moseq_infer.Model()
model_id Unique ID for each model | model_name User-friendly model name | model_dir Model directory relative to root data directory | latent_dim Latent dimension of the model | kappa Kappa value of the model | model_desc Optional. User-defined description of the model |
---|---|---|---|---|---|
1 | model 1 | kpms_project_tutorial/2024_03_28-18_54_08 | 4 | 10000.0 | |
2 | model 2 | kpms_project_tutorial/2024_03_28-18_15_54 | 4 | 5000.0 | |
Total: 2
Optional: Model comparison to select a model
The expected marginal likelihood (EML) score can be used to rank models. The model with the highest EML score can then be selected for further analysis.
model_names = (moseq_train.FullFit).fetch("model_name")
checkpoint_paths = []
for model_name in model_names:
    checkpoint_paths.append(
        get_kpms_processed_data_dir() / Path(model_name) / "checkpoint.h5"
    )
checkpoint_paths
from keypoint_moseq import expected_marginal_likelihoods, plot_eml_scores
eml_scores, eml_std_errs = expected_marginal_likelihoods(
checkpoint_paths=checkpoint_paths
)
best_model = model_names[np.argmax(eml_scores)]
print(f"Best model: {best_model}")
plot_eml_scores(eml_scores, eml_std_errs, model_names)
100%|█████████████████████████████████████| 2/2 [00:07<00:00, 3.64s/it]
Best model: kpms_project_tutorial/2024_03_28-18_15_54
/home/vscode/.local/lib/python3.9/site-packages/keypoint_moseq/viz.py:2914: UserWarning: Tight layout not applied. The bottom and top margins cannot be made large enough to accommodate all Axes decorations.
(<Figure size 400x350 with 1 Axes>, <Axes: ylabel='EML score'>)
Thus, we choose the best ranked model for the inference task:
best_model_id = (moseq_infer.Model & "model_dir = '{}'".format(best_model)).fetch1(
"model_id"
)
print(f"Best model id: {best_model_id}")
Best model id: 2
For tutorial purposes, we'll utilize the same video set (`videos_path`) employed for model training as the video set for inference. This will be incorporated into the `VideoRecording` table as well.
recording_key = {
**session_key,
"recording_id": 1,
}
moseq_infer.VideoRecording.insert1(
{**recording_key, "device": "Camera1"}, skip_duplicates=True
)
for idx, video_name in enumerate(videos_path):
    moseq_infer.VideoRecording.File.insert1(
        dict(**recording_key, file_id=idx, file_path=video_name), skip_duplicates=True
    )
moseq_infer.VideoRecording * moseq_infer.VideoRecording.File
subject | session_datetime | recording_id Unique ID for each recording | file_id Unique ID for each file | device | file_path Filepath of each video, relative to root data directory. |
---|---|---|---|---|---|
subject1 | 2024-03-15 14:04:22 | 1 | 0 | Camera1 | dlc_project/videos/21_12_10_def6a_3.top.ir.mp4 |
subject1 | 2024-03-15 14:04:22 | 1 | 1 | Camera1 | dlc_project/videos/22_04_26_cage4_1_1.top.ir.mp4 |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | Camera1 | dlc_project/videos/21_12_10_def6a_1_1.top.ir.mp4 |
subject1 | 2024-03-15 14:04:22 | 1 | 3 | Camera1 | dlc_project/videos/22_27_04_cage4_mouse2_0.top.ir.mp4 |
subject1 | 2024-03-15 14:04:22 | 1 | 4 | Camera1 | dlc_project/videos/22_04_26_cage4_0.top.ir.mp4 |
subject1 | 2024-03-15 14:04:22 | 1 | 5 | Camera1 | dlc_project/videos/21_11_8_one_mouse.top.ir.Mp4 |
subject1 | 2024-03-15 14:04:22 | 1 | 6 | Camera1 | dlc_project/videos/21_12_2_def6b_2.top.ir.mp4 |
subject1 | 2024-03-15 14:04:22 | 1 | 7 | Camera1 | dlc_project/videos/21_12_10_def6b_3.top.ir.Mp4 |
subject1 | 2024-03-15 14:04:22 | 1 | 8 | Camera1 | dlc_project/videos/22_04_26_cage4_0_2.top.ir.mp4 |
subject1 | 2024-03-15 14:04:22 | 1 | 9 | Camera1 | dlc_project/videos/21_12_2_def6a_1.top.ir.mp4 |
Total: 10
The `InferenceTask` table serves the purpose of specifying an inference task:
moseq_infer.InferenceTask.heading
#
subject              : varchar(8)           #
session_datetime     : datetime             #
recording_id         : int                  # Unique ID for each recording
model_id             : int                  # Unique ID for each model
---
pose_estimation_method : char(15)           # Supported pose estimation method (deeplabcut, sleap, anipose, sleap-anipose, nwb, facemap)
keypointset_dir      : varchar(1000)        # Keypointset directory for the specified VideoRecording
inference_output_dir="" : varchar(1000)     # Optional. Sub-directory where the results will be stored
inference_desc=""    : varchar(1000)        # Optional. User-defined description of the inference task
num_iterations=null  : int                  # Optional. Number of iterations to use for the model inference. If null, the default number internally is 50.
task_mode="load"     : enum('load','trigger') # Task mode for the inference task
Defining and inserting an inference task requires:
- Define the subject and session datetime
- Define the video recording
- Define the pose estimation method used for the video recording
- Choose a model
- Specify the output directory and any optional parameters
inference_task = {**recording_key, "model_id": best_model_id}
moseq_infer.InferenceTask.insert1(
{
**inference_task,
"pose_estimation_method": "deeplabcut",
"keypointset_dir": "dlc_project/videos",
"inference_output_dir": "inference_output",
"inference_desc": "Inference task for the tutorial",
"num_iterations": 5, # Limited iterations for tutorial purposes.
},
skip_duplicates=True,
)
moseq_infer.InferenceTask()
subject | session_datetime | recording_id Unique ID for each recording | model_id Unique ID for each model | pose_estimation_method Supported pose estimation method (deeplabcut, sleap, anipose, sleap-anipose, nwb, facemap) | keypointset_dir Keypointset directory for the specified VideoRecording | inference_output_dir Optional. Sub-directory where the results will be stored | inference_desc Optional. User-defined description of the inference task | num_iterations Optional. Number of iterations to use for the model inference. If null, the default number internally is 50. | task_mode Task mode for the inference task |
---|---|---|---|---|---|---|---|---|---|
subject1 | 2024-03-15 14:04:22 | 1 | 2 | deeplabcut | dlc_project/videos | inference_output | Inference task for the tutorial | 5 | load |
Total: 1
Populating the `Inference` table will automatically extract the learned states of the model (syllables, latent state, centroid, and heading) and store them in the inference output directory together with visualizations and grid movies.
moseq_infer.Inference.populate(inference_task)
/usr/local/lib/python3.9/site-packages/sklearn/base.py:376: InconsistentVersionWarning: Trying to unpickle estimator PCA from version 1.3.2 when using version 1.5.1. This might lead to breaking code or invalid results. Use at your own risk. For more info please refer to: https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations
Loading keypoints: 100%|████████████████| 10/10 [00:07<00:00, 1.27it/s]
moseq_infer.Inference()
subject | session_datetime | recording_id Unique ID for each recording | model_id Unique ID for each model | inference_duration Time duration (seconds) of the inference computation |
---|---|---|---|---|
subject1 | 2024-03-15 14:04:22 | 1 | 2 | nan |
Total: 1
The `MotionSequence` table contains the results of the inference (syllables, latent state, centroid, and heading); a short analysis sketch follows the table below:
moseq_infer.Inference.MotionSequence()
subject | session_datetime | recording_id Unique ID for each recording | model_id Unique ID for each model | video_name Name of the video | syllable Syllable labels (z). The syllable label assigned to each frame (i.e. the state indexes assigned by the model) | latent_state Inferred low-dim pose state (x). Low-dimensional representation of the animal's pose in each frame. These are similar to PCA scores, are modified to reflect the pose dynamics and noise estimates inferred by the model | centroid Inferred centroid (v). The centroid of the animal in each frame, as estimated by the model | heading Inferred heading (h). The heading of the animal in each frame, as estimated by the model |
---|---|---|---|---|---|---|---|---|
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 21_11_8_one_mouse.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000 | =BLOB= | =BLOB= | =BLOB= | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 21_12_10_def6a_1_1.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000 | =BLOB= | =BLOB= | =BLOB= | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 21_12_10_def6a_3.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000 | =BLOB= | =BLOB= | =BLOB= | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 21_12_10_def6b_3.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000 | =BLOB= | =BLOB= | =BLOB= | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 21_12_2_def6a_1.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000 | =BLOB= | =BLOB= | =BLOB= | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 21_12_2_def6b_2.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000 | =BLOB= | =BLOB= | =BLOB= | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 22_04_26_cage4_0_2.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000 | =BLOB= | =BLOB= | =BLOB= | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 22_04_26_cage4_0.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000 | =BLOB= | =BLOB= | =BLOB= | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 22_04_26_cage4_1_1.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000 | =BLOB= | =BLOB= | =BLOB= | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 22_27_04_cage4_mouse2_0.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000 | =BLOB= | =BLOB= | =BLOB= | =BLOB= |
Total: 10
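As a quick example of working with these stored blobs, the snippet below fetches the per-frame syllable labels for each video in this inference task and tabulates how often each syllable occurs in the first one:

```python
# Fetch syllable labels for every video in this inference task
video_names, syllables = (
    moseq_infer.Inference.MotionSequence & inference_task
).fetch("video_name", "syllable")

# Frequency of each syllable label in the first video
labels, counts = np.unique(syllables[0], return_counts=True)
for label, count in zip(labels, counts):
    print(f"syllable {label}: {count} frames")
```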
The `GridMoviesSampledInstances` table contains the sampled instances for the grid movies. The sampled instances are stored as a dictionary mapping each syllable to the list of instances shown in its grid movie (in row-major order).
moseq_infer.Inference.GridMoviesSampledInstances()
subject | session_datetime | recording_id Unique ID for each recording | model_id Unique ID for each model | syllable Syllable label | instances List of instances shown in each in grid movie (in row-major order), where each instance is specified as a tuple with the video name, start frame and end frame |
---|---|---|---|---|---|
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 0 | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 1 | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 2 | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 3 | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 4 | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 5 | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 6 | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 7 | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 8 | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 9 | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 10 | =BLOB= |
subject1 | 2024-03-15 14:04:22 | 1 | 2 | 11 | =BLOB= |
...
Total: 36
instance_syllable_0 = (
moseq_infer.Inference.GridMoviesSampledInstances & "syllable = 0"
).fetch1("instances")
instance_syllable_0
[('22_04_26_cage4_0.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 23065, 23074), ('21_12_2_def6b_2.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 104500, 104518), ('21_11_8_one_mouse.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 7956, 7974), ('21_12_2_def6a_1.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 70500, 70540), ('22_04_26_cage4_0_2.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 39581, 39611), ('21_12_2_def6a_1.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 74072, 74103), ('21_12_2_def6b_2.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 100604, 100645), ('21_12_10_def6a_1_1.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 45835, 45853), ('21_12_2_def6b_2.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 100008, 100022), ('22_04_26_cage4_0.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 21223, 21232), ('21_12_2_def6b_2.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 7026, 7053), ('21_12_10_def6b_3.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 35468, 35504), ('21_12_2_def6a_1.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 63734, 63749), ('22_27_04_cage4_mouse2_0.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 16209, 16214), ('22_27_04_cage4_mouse2_0.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 39123, 39140), ('21_12_2_def6b_2.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 19257, 19274), ('21_12_2_def6a_1.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 62477, 62533), ('21_12_2_def6b_2.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 78366, 78382), ('22_04_26_cage4_0_2.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 44743, 44770), ('22_04_26_cage4_0_2.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 23396, 23478), ('21_11_8_one_mouse.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 37546, 37557), ('21_11_8_one_mouse.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 10772, 10777), ('21_12_2_def6b_2.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 41181, 41195), ('21_12_10_def6b_3.top.irDLC_resnet50_moseq_exampleAug21shuffle1_500000', 24808, 24845)]
Each instance of syllable 0 is represented as a tuple containing the video name, start frame, and end frame. This format facilitates downstream analysis, for example computing syllable durations as sketched below.
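A minimal sketch of such an analysis, reusing the average frame rate (`fps`) fetched from `PCAPrep` earlier in this notebook:

```python
# Durations of the sampled instances of syllable 0, in frames and milliseconds
durations_frames = np.array([end - start for _, start, end in instance_syllable_0])
durations_ms = durations_frames / fps * 1000

print(
    f"median duration: {np.median(durations_frames):.0f} frames "
    f"({np.median(durations_ms):.0f} ms)"
)
```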
import ipywidgets as widgets
from IPython.display import Image, display
from element_interface.utils import find_full_path
syllable_id = 0
model_dir = (moseq_infer.Model & inference_task).fetch1("model_dir")
inference_output_dir = (
moseq_infer.InferenceTask * moseq_infer.Inference.MotionSequence & inference_task
).fetch("inference_output_dir", limit=1)[0]
model_path = find_full_path(get_kpms_processed_data_dir(), model_dir)
video_path = (
model_path
/ inference_output_dir
/ "grid_movies"
/ ("syllable" + str(syllable_id) + ".mp4")
).as_posix()
print(video_path)
gif_path = (
model_path
/ inference_output_dir
/ "trajectory_plots"
/ ("Syllable" + str(syllable_id) + ".gif")
).as_posix()
gif_path1 = (
model_path
/ inference_output_dir
/ "trajectory_plots"
/ ("Syllable" + str(syllable_id + 1) + ".gif")
).as_posix()
gif_path2 = (
model_path
/ inference_output_dir
/ "trajectory_plots"
/ ("Syllable" + str(syllable_id + 2) + ".gif")
).as_posix()
video_widget = widgets.Video.from_file(video_path, format="mp4", width=640, height=480)
display(video_widget)
display(Image(filename=gif_path))
display(Image(filename=gif_path1))
display(Image(filename=gif_path2))
/workspaces/element-moseq/example_data/outbox/kpms_project_tutorial/2024_03_28-18_15_54/inference_output/grid_movies/syllable0.mp4
Video(value=b'\x00\x00\x00 ftypisom\x00\x00\x02\x00isomiso2avc1mp41\x00\x00\x00\x08free...', height='480', wid…
<IPython.core.display.Image object>
<IPython.core.display.Image object>
<IPython.core.display.Image object>
Summary
Following this tutorial, we have:
- Covered the essential functionality of `element-moseq`
- Acquired the skills to load keypoint data and insert metadata into the pipeline
- Learned how to fit a PCA model and run the AR-HMM and Keypoint-SLDS fittings
- Executed and ingested the results of the motion sequencing analysis with Keypoint-MoSeq
- Visualized and stored the results
Documentation and DataJoint tutorials
- Detailed documentation on `element-moseq`
- General `DataJoint-Python` interactive tutorials, covering fundamentals such as table tiers, query operations, fetch operations, automated computations with the make function, and more.
- Documentation for `DataJoint-Python`