[BUG] Unable to extract session embeddings from a session-based transformer model #163

Description

@rnyak

Bug description

I am trying to extract session embeddings from a session-based transformer model, but neither of the following options works.

Option 1:

I tried both of the following calls, but neither works:

model_transformer.query_embeddings(train, index='session_id')

or

model_transformer.query_embeddings(train, batch_size=1024, index='session_id')

Option 2:

I am able to generate session embeddings for a single batch, but iterating over the loader batch by batch crashes.

This works:
model_transformer.query_encoder(batch[0])

But iterating over the loader batch by batch does not work:

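import gc

all_sess_embeddings = []
for batch, _ in iter(loader):
    # Encode one batch of sessions and convert the result to a NumPy array
    embds = model_transformer.query_encoder(batch).numpy()
    # Release the batch before loading the next one
    del batch
    gc.collect()
    all_sess_embeddings.append(embds)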

Steps/Code to reproduce bug

The code to reproduce the issue is available in this gist:

https://gist.github.com/rnyak/d70822084c26ba6972615512e8a78bb2

Expected behavior

We should be able to extract session embeddings from the transformer model's query encoder for the full dataset without errors.
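
For illustration, the expected result can be pictured as one embedding vector per session, keyed by session_id and built up batch by batch. The following is only a sketch under stated assumptions: `collect_session_embeddings`, `toy_encoder`, `toy_batches`, and the `item_count` feature are hypothetical stand-ins invented for this example, not Merlin APIs, and plain Python lists replace the model output that the loop above converts with `.numpy()`.

```python
from typing import Callable, Dict, Iterable, List, Sequence, Tuple

Vector = List[float]
Batch = Dict[str, List[int]]


def collect_session_embeddings(
    batches: Iterable[Tuple[Batch, object]],
    encode: Callable[[Batch], Sequence[Vector]],
) -> Dict[int, Vector]:
    """Encode every batch and key each embedding by its session_id."""
    embeddings: Dict[int, Vector] = {}
    for features, _ in batches:  # targets are ignored, as in the loop above
        vectors = encode(features)
        # Pair each session in the batch with its embedding row
        for session_id, vector in zip(features["session_id"], vectors):
            embeddings[session_id] = list(vector)
    return embeddings


def toy_encoder(features: Batch) -> List[Vector]:
    """Hypothetical stand-in for model_transformer.query_encoder."""
    return [[float(n), float(n) / 2.0] for n in features["item_count"]]


# Hypothetical (features, targets) batches shaped like the loader's output
toy_batches = [
    ({"session_id": [1, 2], "item_count": [4, 6]}, None),
    ({"session_id": [3], "item_count": [8]}, None),
]

session_embeddings = collect_session_embeddings(toy_batches, toy_encoder)
```

With the real model, `encode` would be the query encoder call from Option 2, and the per-batch arrays could be joined with `numpy.concatenate` instead of being stored in a dictionary.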

Environment details

  • Merlin version:
  • Platform:
  • Python version:
  • PyTorch version (GPU?):
  • Tensorflow version (GPU?): Using tensorflow 23.06 image with the latest branches pulled.

Labels

P0, bug (Something isn't working)
