moseq_train.py
activate(train_schema_name, infer_schema_name=None, *, create_schema=True, create_tables=True, linking_module=None)
Activate this schema.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
train_schema_name | str | A string containing the name of the training schema. | required |
infer_schema_name | str | A string containing the name of the inference schema. | None |
create_schema | bool | If True (default), the schema will be created in the database. | True |
create_tables | bool | If True (default), tables related to the schema will be created in the database. | True |
linking_module | str | A string containing the module name, or the module itself, that provides the dependencies required to activate the schema. | None |
Dependencies:
Functions:
- get_kpms_root_data_dir(): Returns the absolute path for the root data director(y/ies) with all behavioral recordings, as (list of) string(s).
- get_kpms_processed_data_dir(): Optional. Returns the absolute path for processed data.
Source code in element_moseq/moseq_train.py
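A minimal sketch of a linking module that provides the two functions in the dependency list. Only the function names come from this page. The directory values and module layout are assumptions for illustration. The activation call is left as a comment because it needs DataJoint, element_moseq, and a database connection, and the schema names in it are hypothetical.

```python
"""Hypothetical linking module supplying the functions activate() looks for."""
from pathlib import Path
from typing import List

# ASSUMED locations: replace with the directories used in your setup.
KPMS_ROOT_DATA_DIRS = ["/data/behavior_recordings"]
KPMS_PROCESSED_DATA_DIR = "/data/kpms_processed"


def get_kpms_root_data_dir() -> List[str]:
    """Return the absolute path(s) of the root data director(y/ies)."""
    return [str(Path(d).expanduser().resolve()) for d in KPMS_ROOT_DATA_DIRS]


def get_kpms_processed_data_dir() -> str:
    """Optional: return the absolute path for processed data."""
    return str(Path(KPMS_PROCESSED_DATA_DIR).expanduser().resolve())


# With DataJoint and element_moseq installed, activation might look like:
# from element_moseq import moseq_train
# moseq_train.activate("my_moseq_train", "my_moseq_infer", linking_module=__name__)
```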
KeypointSet
Bases: Manual
Store the keypoint data and the video set directory for model training.
Attributes:
Name | Type | Description |
---|---|---|
kpset_id | int | Unique ID for each keypoint set. |
PoseEstimationMethod | foreign key | Unique format method used to obtain the keypoints data. |
kpset_dir | str | Path where the keypoint files are located together with the pose estimation config file. |
kpset_desc | str | Optional. User-entered description. |
Source code in element_moseq/moseq_train.py
VideoFile
Bases: Part
Store the IDs and file paths of each video file that will be used for model training.
Attributes:
Name | Type | Description |
---|---|---|
KeypointSet | foreign key | Unique ID for each keypoint set. |
video_id | int | Unique ID for each video corresponding to each keypoint data file. |
video_path | str | Filepath of each video from which the keypoints are derived, relative to root data directory. |
Source code in element_moseq/moseq_train.py
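The sketch below builds one KeypointSet entry and its VideoFile part entries as plain dictionaries. All values are hypothetical. The attribute name `format_method`, assumed here to be the primary key inherited from PoseEstimationMethod, is also an assumption and does not come from this page. The insert calls are shown as comments because they require a live database.

```python
# Hypothetical KeypointSet entry; `format_method` is ASSUMED to be the
# primary-key attribute inherited from PoseEstimationMethod.
kpset_key = {"kpset_id": 1}

keypoint_set = {
    **kpset_key,
    "format_method": "deeplabcut",
    "kpset_dir": "dlc_project/videos",  # illustrative path only
    "kpset_desc": "Keypoints from three open-field sessions",
}

# One VideoFile part entry per video; each repeats the master's primary key.
video_paths = [
    "dlc_project/videos/session1.mp4",
    "dlc_project/videos/session2.mp4",
    "dlc_project/videos/session3.mp4",
]
video_files = [
    {**kpset_key, "video_id": video_id, "video_path": path}
    for video_id, path in enumerate(video_paths, start=1)
]

# With DataJoint, element_moseq, and a database connection, the rows could be inserted as:
# moseq_train.KeypointSet.insert1(keypoint_set)
# moseq_train.KeypointSet.VideoFile.insert(video_files)
```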
Bodyparts
Bases: Manual
Store the body parts to use in the analysis.
Attributes:
Name | Type | Description |
---|---|---|
KeypointSet | foreign key | Unique ID for each keypoint set. |
bodyparts_id | int | Unique ID for a set of bodyparts for a particular keypoint set. |
anterior_bodyparts | blob | List of strings of anterior bodyparts. |
posterior_bodyparts | blob | List of strings of posterior bodyparts. |
use_bodyparts | blob | List of strings of bodyparts to be used. |
bodyparts_desc | varchar | Optional. User-entered description. |
Source code in element_moseq/moseq_train.py
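A small sanity check that could be run on a Bodyparts entry before inserting it. The rule it enforces is an assumption of this sketch, not a constraint stated on this page: every anterior and posterior bodypart should also appear in use_bodyparts. The helper name `check_bodyparts` and the bodypart names are hypothetical.

```python
from typing import Any, Dict, List


def check_bodyparts(entry: Dict[str, Any]) -> List[str]:
    """Return a list of problems found in a Bodyparts entry (empty if none).

    ASSUMED rule: anterior/posterior bodyparts should be a subset of use_bodyparts.
    """
    used = set(entry["use_bodyparts"])
    problems = []
    for field in ("anterior_bodyparts", "posterior_bodyparts"):
        missing = [bp for bp in entry[field] if bp not in used]
        if missing:
            problems.append(f"{field} not in use_bodyparts: {missing}")
    return problems


# Hypothetical entry for kpset_id 1 that excludes tail keypoints.
bodyparts_entry = {
    "kpset_id": 1,
    "bodyparts_id": 1,
    "anterior_bodyparts": ["nose"],
    "posterior_bodyparts": ["spine4"],
    "use_bodyparts": ["spine4", "spine3", "spine2", "spine1", "head", "nose", "right ear", "left ear"],
    "bodyparts_desc": "Exclude tail keypoints",
}
```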
PCATask
Bases: Manual
Staging table to define the PCA task and its output directory.
Attributes:
Name | Type | Description |
---|---|---|
Bodyparts | foreign key | Unique ID for each Bodyparts key. |
kpms_project_output_dir | str | Keypoint-MoSeq project output directory, relative to root data directory. |
Source code in element_moseq/moseq_train.py
PCAPrep
Bases: Imported
Table to set up the Keypoint-MoSeq project output directory (kpms_project_output_dir), creating the default config.yml and updating it in a new kpms_dj_config.yml.
Attributes:
Name | Type | Description |
---|---|---|
PCATask | foreign key | Unique ID for each PCATask key. |
coordinates | longblob | Dictionary mapping filenames to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2[or 3]). |
confidences | longblob | Dictionary mapping filenames to keypoint confidence scores. |
formatted_bodyparts | longblob | List of bodypart names. The order of the names matches the order of the bodyparts in coordinates and confidences. |
average_frame_rate | float | Average frame rate of the videos for model training. |
frame_rates | longblob | List of the frame rates of the videos for model training. |
Source code in element_moseq/moseq_train.py
make(key)
Make function to:
1. Generate and update the kpms_dj_config.yml with both the video set directory and the bodyparts.
2. Create the keypoint coordinates and confidence scores to format the data for the PCA fitting.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | dict | Primary key from the PCATask table. | required |
Raises:
Type | Description |
---|---|
NotImplementedError | |
High-Level Logic:
1. Fetch the bodyparts, the format method, and the directories for the Keypoint-MoSeq project output, the keypoint set, and the video set.
2. Set variables for the full path of each of these directories.
3. Find the first existing pose estimation config file in the kpset_dir directory; if none is found, raise an error.
4. Check that the pose_estimation_method is deeplabcut and set up the project output directory with the default config.yml.
5. Create the kpms_project_output_dir (if it does not exist), and generate the kpms default config.yml with the default values from the pose estimation config.
6. Create a copy of the kpms config.yml named kpms_dj_config.yml that will be updated with both the video_dir and bodyparts.
7. Load keypoint data from the keypoint files found in the kpset_dir that will serve as the training set.
8. As a result of the keypoint loading, the coordinates and confidence scores are generated and will be used to format the data for modeling.
9. Calculate the average frame rate and the frame rate list of the video set from which the keypoint set is derived. These two attributes can be used to calculate the kappa value.
10. Insert the results of this make function into the table.
Source code in element_moseq/moseq_train.py
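Steps 3 and 9 of the logic above can be sketched with the standard library. The candidate config file names (config.yaml, config.yml) and the helper names `find_first_config` and `summarize_frame_rates` are hypothetical. The actual make function may search for different files and read frame rates from the video metadata.

```python
from pathlib import Path
from statistics import mean
from typing import Iterable, List, Tuple

# ASSUMED candidate names for a pose estimation (e.g. DeepLabCut) config file.
CONFIG_CANDIDATES = ("config.yaml", "config.yml")


def find_first_config(kpset_dir: Path, candidates: Iterable[str] = CONFIG_CANDIDATES) -> Path:
    """Return the first candidate config file that exists in kpset_dir."""
    for name in candidates:
        path = Path(kpset_dir) / name
        if path.exists():
            return path
    raise FileNotFoundError(f"No pose estimation config file found in {kpset_dir}")


def summarize_frame_rates(frame_rates: List[float]) -> Tuple[float, List[float]]:
    """Return (average_frame_rate, frame_rates), mirroring the two PCAPrep attributes."""
    if not frame_rates:
        raise ValueError("At least one video frame rate is required")
    return mean(frame_rates), list(frame_rates)
```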
PCAFit
Bases: Computed
Fit PCA model.
Attributes:
Name | Type | Description |
---|---|---|
PCAPrep | foreign key | |
pca_fit_time | datetime | Datetime of the PCA fitting analysis. |
Source code in element_moseq/moseq_train.py
make(key)
Make function to format the keypoint data, fit the PCA model, and store it as a pca.p file in the Keypoint-MoSeq project output directory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | dict | | required |
Raises:
High-Level Logic:
1. Fetch the kpms_project_output_dir from the PCATask table and define its full path.
2. Load the kpms_dj_config file that contains the updated video_dir and bodyparts, and format the keypoint data with the coordinates and confidence scores to be used in the PCA fitting.
3. Fit the PCA model and save it as a pca.p file in the output directory.
4. Insert the creation datetime as the pca_fit_time into the table.
Source code in element_moseq/moseq_train.py
LatentDimension
Bases: Imported
Determine the latent dimension as part of the autoregressive hyperparameters (ar_hypparams) for the model fitting.
The objective of the analysis is to inform the user about the number of principal components needed to explain a 90% variance threshold. The decision on how many components to use for the model fitting is then left to the user.
Attributes:
Name | Type | Description |
---|---|---|
PCAFit | foreign key | |
variance_percentage | float | Variance threshold, fixed at 90%. If all components together explain less than 90% of the variance, this is instead the total percentage of variance they explain. |
latent_dimension | int | Number of principal components required to explain the specified variance. |
latent_dim_desc | varchar | Automated description of the computation result. |
Source code in element_moseq/moseq_train.py
make(key)
Make function to compute and store the latent dimension that explains a 90% variance threshold.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | dict | | required |
Raises:
High-Level Logic:
1. Fetch the Keypoint-MoSeq project output directory from the PCATask table and define the full path.
2. Load the PCA model from file in this directory.
3. Set the variance threshold to 90% and compute the cumulative sum of the explained variance ratio.
4. Determine the number of components required to explain the specified variance.
4.1 If the cumulative sum of the explained variance ratio over all components is less than the variance threshold, set latent_dimension to the total number of components and variance_percentage to that cumulative sum.
4.2 Otherwise, set latent_dimension to the number of components that explain the specified variance and variance_percentage to the variance threshold.
5. Insert the results of this make function into the table.
Source code in element_moseq/moseq_train.py
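The threshold logic of this make function can be sketched in plain Python. The sketch assumes `explained_variance_ratio` is the per-component list of variance ratios from the fitted PCA model, for example scikit-learn's `PCA.explained_variance_ratio_`. The function name and the wording of the generated description are hypothetical. A cumulative sum exactly equal to the threshold is treated as sufficient, which is an assumption: the logic above only distinguishes "less than" from "greater than".

```python
from itertools import accumulate
from typing import Sequence, Tuple


def latent_dimension_from_variance(
    explained_variance_ratio: Sequence[float], variance_threshold: float = 0.90
) -> Tuple[int, float, str]:
    """Return (latent_dimension, variance_percentage, latent_dim_desc)."""
    cumsum = list(accumulate(explained_variance_ratio))
    if not cumsum:
        raise ValueError("explained_variance_ratio must not be empty")
    if cumsum[-1] < variance_threshold:
        # All components together stay below the threshold: use them all.
        total = cumsum[-1] * 100
        return len(cumsum), total, f"All components together only explain {total:.1f}% of variance."
    # Smallest number of components whose cumulative ratio reaches the threshold.
    latent_dimension = next(i for i, c in enumerate(cumsum, start=1) if c >= variance_threshold)
    pct = variance_threshold * 100
    return latent_dimension, pct, f">={pct:.0f}% of variance explained by {latent_dimension} components."
```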
PreFitTask
Bases: Manual
Insert the parameters for the model (AR-HMM) pre-fitting.
Attributes:
Name | Type | Description |
---|---|---|
PCAFit | foreign key | |
pre_latent_dim | int | Latent dimension to use for the model pre-fitting. |
pre_kappa | int | Kappa value to use for the model pre-fitting. |
pre_num_iterations | int | Number of Gibbs sampling iterations to run in the model pre-fitting. |
pre_fit_desc | varchar | User-defined description of the pre-fitting task. |
Source code in element_moseq/moseq_train.py
PreFit
Bases: Computed
Fit AR-HMM model.
Attributes:
Name | Type | Description |
---|---|---|
PreFitTask | foreign key | |
model_name | varchar | Name of the model as "kpms_project_output_dir/model_name". |
pre_fit_duration | float | Time duration (seconds) of the model fitting computation. |
Source code in element_moseq/moseq_train.py
make(key)
Make function to fit the AR-HMM model using the latent trajectory defined by `model['states']['x']`.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | dict | Dictionary with the PreFitTask primary key. | required |
Raises:
High-level Logic:
1. Fetch the kpms_project_output_dir and define the full path.
2. Fetch the model parameters from the PreFitTask table.
3. Update the dj_config.yml with the latent dimension and kappa for the AR-HMM fitting.
4. Load the PCA model from file in the kpms_project_output_dir.
5. Fetch the coordinates and confidence scores to format the data for the model initialization. Formatting yields the data for model fitting and the metadata (the recordings and start/end frames for the data).
6. Initialize the model, creating a model dict containing states, parameters, hyperparameters, noise prior, and random seed.
7. Update the model dict with the selected kappa for the AR-HMM fitting.
8. Fit the AR-HMM model using pre_num_iterations and create a subdirectory in kpms_project_output_dir with the model's latest checkpoint file.
9. Calculate the duration of the model fitting computation and insert it into the PreFit table.
Source code in element_moseq/moseq_train.py
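The duration measurement in the last step can be sketched as a wrapper around any fitting call. `timed_fit` and the stand-in `fake_fit` are hypothetical. The real fitting call and its arguments are not shown on this page. `time.perf_counter` is used because it is a monotonic, high-resolution clock intended for measuring elapsed time.

```python
import time
from typing import Any, Callable, Tuple


def timed_fit(fit_fn: Callable[..., Any], *args: Any, **kwargs: Any) -> Tuple[Any, float]:
    """Run a fitting callable and return (result, elapsed seconds)."""
    start = time.perf_counter()  # monotonic: unaffected by system clock changes
    result = fit_fn(*args, **kwargs)
    return result, time.perf_counter() - start


def fake_fit(num_iterations: int) -> dict:
    """Hypothetical stand-in for the real AR-HMM fitting call."""
    total = 0
    for i in range(num_iterations):
        total += i
    return {"iterations": num_iterations, "checksum": total}


model, pre_fit_duration = timed_fit(fake_fit, num_iterations=1000)
```

The same wrapper would apply unchanged to the full-fit duration.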
FullFitTask
Bases: Manual
Insert the parameters for the full (Keypoint-SLDS model) fitting. The full model will generally require a lower value of kappa to yield the same target syllable durations.
Attributes:
Name | Type | Description |
---|---|---|
PCAFit | foreign key | |
full_latent_dim | int | Latent dimension to use for the full model fitting. |
full_kappa | int | Kappa value to use for the full model fitting. |
full_num_iterations | int | Number of Gibbs sampling iterations to run in the full model fitting. |
full_fit_desc | varchar | User-defined description of the full model fitting task. |
Source code in element_moseq/moseq_train.py
FullFit
Bases: Computed
Fit the full (Keypoint-SLDS) model.
Attributes:
Name | Type | Description |
---|---|---|
FullFitTask | foreign key | |
model_name | varchar(100) | Name of the model as "kpms_project_output_dir/model_name". |
full_fit_duration | float | Time duration (seconds) of the full fitting computation. |
Source code in element_moseq/moseq_train.py
make(key)
Make function to fit the full (Keypoint-SLDS) model.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
key | dict | Dictionary with the FullFitTask primary key. | required |
Raises:
High-level Logic:
1. Fetch the kpms_project_output_dir and define the full path.
2. Fetch the model parameters from the FullFitTask table.
3. Update the dj_config.yml with the selected latent dimension and kappa for the full fitting.
4. Initialize and fit the full model in a new model_name directory.
5. Load the PCA model from file in the kpms_project_output_dir.
6. Fetch the coordinates and confidence scores to format the data for the model initialization.
7. Initialize the model, creating a model dict containing states, parameters, hyperparameters, noise prior, and random seed.
8. Update the model dict with the selected kappa for the Keypoint-SLDS fitting.
9. Fit the Keypoint-SLDS model using full_num_iterations and create a subdirectory in kpms_project_output_dir with the model's latest checkpoint file.
10. Reindex syllable labels by their frequency in the most recent model snapshot in the checkpoint file. This permutes the states and parameters of the saved checkpoint so that syllables are labeled in order of frequency (i.e., 0 is the most frequent, 1 is the second most, and so on).
11. Calculate the duration of the model fitting computation and insert it into the FullFit table.
Source code in element_moseq/moseq_train.py
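The syllable reindexing step above can be illustrated on a flat sequence of syllable labels. This sketch only relabels a label list. It does not permute the model states and parameters stored in a checkpoint, which the actual step also does. Breaking ties by the original label is an assumption. `reindex_by_frequency` is a hypothetical name.

```python
from collections import Counter
from typing import Dict, List, Sequence, Tuple


def reindex_by_frequency(labels: Sequence[int]) -> Tuple[List[int], Dict[int, int]]:
    """Relabel syllables so 0 is the most frequent, 1 the second most, and so on.

    Ties are broken by the original label (an assumption of this sketch).
    Returns the relabeled sequence and the old -> new label mapping.
    """
    counts = Counter(labels)
    order = sorted(counts, key=lambda s: (-counts[s], s))
    mapping = {old: new for new, old in enumerate(order)}
    return [mapping[s] for s in labels], mapping
```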