diff --git a/sea-orm-codegen/src/entity/column.rs b/sea-orm-codegen/src/entity/column.rs index 25f650fee7..2188be899d 100644 --- a/sea-orm-codegen/src/entity/column.rs +++ b/sea-orm-codegen/src/entity/column.rs @@ -130,6 +130,33 @@ impl Column { col_type.map(|ty| quote! { column_type = #ty }) } + pub fn get_ts_type_attrs( + &self, + model_extra_derives: &TokenStream, + model_extra_attributes: &TokenStream, + ) -> Option { + if !matches!(self.col_type, ColumnType::Vector(_)) { + return None; + } + + let mut attrs = Vec::new(); + let tokens = format!("{}{}", model_extra_derives, model_extra_attributes) + .replace(|c: char| c.is_whitespace(), ""); + + if tokens.contains("ts_rs::TS") || tokens.contains("ts(export)") { + attrs.push(quote! { #[ts(type = "number[]")] }); + } + if tokens.contains("specta::Type") || tokens.contains("specta(export)") { + attrs.push(quote! { #[specta(type = "number[]")] }); + } + + if attrs.is_empty() { + None + } else { + Some(quote! { #(#attrs)* }) + } + } + pub fn get_def(&self) -> TokenStream { fn write_col_def(col_type: &ColumnType) -> TokenStream { match col_type { @@ -369,9 +396,38 @@ mod tests { make_col!("date_time", ColumnType::DateTime), make_col!("timestamp", ColumnType::Timestamp), make_col!("timestamp_tz", ColumnType::TimestampWithTimeZone), + make_col!("embedding", ColumnType::Vector(None)), ] } + #[test] + fn test_get_ts_type_attrs() { + let col = Column { + name: "embedding".to_owned(), + col_type: ColumnType::Vector(None), + auto_increment: false, + not_null: false, + unique: false, + unique_key: None, + }; + + let ts_attr = col + .get_ts_type_attrs( + "e! { ts_rs::TS }, + &TokenStream::new(), + ) + .expect("Expected ts attribute"); + assert_eq!(ts_attr.to_string(), "# [ts (type = \"number[]\")]"); + + let specta_attr = col + .get_ts_type_attrs( + "e! { specta::Type }, + &TokenStream::new(), + ) + .expect("Expected specta attribute"); + assert_eq!(specta_attr.to_string(), "# [specta (type = \"number[]\")]"); + } + #[test] fn test_get_name_snake_case() { let columns = setup(); diff --git a/sea-orm-codegen/src/entity/writer.rs b/sea-orm-codegen/src/entity/writer.rs index ba1d3ce7e6..63fa4d6ed1 100644 --- a/sea-orm-codegen/src/entity/writer.rs +++ b/sea-orm-codegen/src/entity/writer.rs @@ -2896,6 +2896,55 @@ mod tests { Ok(()) } + #[test] + fn test_gen_with_ts_vector_support() -> io::Result<()> { + let entity = Entity { + table_name: "document".to_owned(), + columns: vec![ + Column { + name: "id".to_owned(), + col_type: ColumnType::Integer, + auto_increment: true, + not_null: true, + unique: false, + unique_key: None, + }, + Column { + name: "embedding".to_owned(), + col_type: ColumnType::Vector(None), + auto_increment: false, + not_null: false, + unique: false, + unique_key: None, + }, + ], + relations: vec![], + conjunct_relations: vec![], + primary_keys: vec![PrimaryKey { + name: "id".to_owned(), + }], + }; + + let generated = generated_to_string(EntityWriter::gen_compact_code_blocks( + &entity, + &WithSerde::None, + &default_column_option(), + &None, + false, + false, + &bonus_derive(["ts_rs::TS"]), + &TokenStream::new(), + &TokenStream::new(), + false, + true, + )); + + assert!(generated.contains("# [ts (type = \"number[]\")]")); + assert!(generated.contains("pub embedding : Option < PgVector >")); + + Ok(()) + } + #[test] fn test_gen_import_active_enum() -> io::Result<()> { let entities = vec![ diff --git a/sea-orm-codegen/src/entity/writer/compact.rs b/sea-orm-codegen/src/entity/writer/compact.rs index 6f9e0c423c..c17d4ecd9d 100644 --- a/sea-orm-codegen/src/entity/writer/compact.rs +++ b/sea-orm-codegen/src/entity/writer/compact.rs @@ -99,6 +99,10 @@ impl EntityWriter { } ts = quote! { #[sea_orm(#ts)] }; } + let ts_type_attribute = col.get_ts_type_attrs( + model_extra_derives, + model_extra_attributes, + ); let serde_attribute = col.get_serde_attribute( is_primary_key, serde_skip_deserializing_primary_key, @@ -106,6 +110,7 @@ impl EntityWriter { ); ts = quote! { #ts + #ts_type_attribute #serde_attribute }; ts diff --git a/sea-orm-codegen/src/entity/writer/dense.rs b/sea-orm-codegen/src/entity/writer/dense.rs index 4338071417..ced98416d7 100644 --- a/sea-orm-codegen/src/entity/writer/dense.rs +++ b/sea-orm-codegen/src/entity/writer/dense.rs @@ -95,6 +95,10 @@ impl EntityWriter { } ts = quote! { #[sea_orm(#ts)] }; } + let ts_type_attribute = col.get_ts_type_attrs( + model_extra_derives, + model_extra_attributes, + ); let serde_attribute = col.get_serde_attribute( is_primary_key, serde_skip_deserializing_primary_key, @@ -102,6 +106,7 @@ impl EntityWriter { ); ts = quote! { #ts + #ts_type_attribute #serde_attribute }; ts diff --git a/sea-orm-codegen/src/entity/writer/expanded.rs b/sea-orm-codegen/src/entity/writer/expanded.rs index 76b81ee843..3d0a0a4bc1 100644 --- a/sea-orm-codegen/src/entity/writer/expanded.rs +++ b/sea-orm-codegen/src/entity/writer/expanded.rs @@ -60,10 +60,26 @@ impl EntityWriter { let column_names_snake_case = entity.get_column_names_snake_case(); let column_rs_types = entity.get_column_rs_types(column_option); let if_eq_needed = entity.get_eq_needed(); - let serde_attributes = entity.get_column_serde_attributes( - serde_skip_deserializing_primary_key, - serde_skip_hidden_column, - ); + let column_attributes: Vec = entity + .columns + .iter() + .map(|col| { + let is_primary_key = entity.primary_keys.iter().any(|pk| pk.name == col.name); + let ts_type_attribute = col.get_ts_type_attrs( + model_extra_derives, + model_extra_attributes, + ); + let serde_attribute = col.get_serde_attribute( + is_primary_key, + serde_skip_deserializing_primary_key, + serde_skip_hidden_column, + ); + quote! { + #ts_type_attribute + #serde_attribute + } + }) + .collect(); let extra_derive = with_serde.extra_derive(); quote! { @@ -71,7 +87,7 @@ impl EntityWriter { #model_extra_attributes pub struct Model { #( - #serde_attributes + #column_attributes pub #column_names_snake_case: #column_rs_types, )* } diff --git a/sea-orm-codegen/src/entity/writer/frontend.rs b/sea-orm-codegen/src/entity/writer/frontend.rs index 6b2fa2507c..ecc066330c 100644 --- a/sea-orm-codegen/src/entity/writer/frontend.rs +++ b/sea-orm-codegen/src/entity/writer/frontend.rs @@ -57,11 +57,19 @@ impl EntityWriter { .iter() .map(|col| { let is_primary_key = primary_keys.contains(&col.name); - col.get_serde_attribute( + let ts_type_attribute = col.get_ts_type_attrs( + model_extra_derives, + model_extra_attributes, + ); + let serde_attribute = col.get_serde_attribute( is_primary_key, serde_skip_deserializing_primary_key, serde_skip_hidden_column, - ) + ); + quote! { + #ts_type_attribute + #serde_attribute + } }) .collect(); let extra_derive = with_serde.extra_derive();